Commit 01edcd4

ezyang authored and pytorchmergebot committed
Make distributed modules importable even when backend not built (pytorch#159889)
This PR is greatly simplified now that it is stacked on top of a PR that builds with distributed always. We only need to stub functions that may not be defined due to a backend not being enabled.

Signed-off-by: Edward Yang <[email protected]>
Pull Request resolved: pytorch#159889
Approved by: https://github.com/wconstab
ghstack dependencies: pytorch#160449
1 parent de893e9 commit 01edcd4
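
For context on the approach: since _C._distributed_c10d is now always built, keeping modules importable only requires a Python-level fallback for backend-specific symbols. Below is a minimal illustrative sketch of that try/except import-fallback pattern, assuming the real bindings are preferred when present; the actual wiring lives in files not shown in this excerpt.

# Illustrative sketch only, not the literal wiring from this PR.
try:
    # Real binding, present only when PyTorch was built with NCCL support.
    from torch._C._distributed_c10d import ProcessGroupNCCL
except ImportError:
    # Pure-Python placeholder so `import torch.distributed` still succeeds.
    from torch.distributed._C_stubs import ProcessGroupNCCL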

21 files changed: +641 −235 lines


.ci/pytorch/macos-test.sh

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available(
 fi
 popd

+python -mpip install -r requirements.txt
+
 # enable debug asserts in serialization
 export TORCH_SERIALIZATION_DEBUG=1

test/distributed/tensor/test_fake.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed.fake_pg import FakeStore


class TestFakeDTensor(TestCase):
    def test_fake_dtensor_operations(self):
        # Use FakeTensorMode to handle CUDA tensors without actual CUDA
        fake_mode = FakeTensorMode()
        world_size = 4

        fake_store = FakeStore()
        torch.distributed.init_process_group(
            "fake", store=fake_store, rank=0, world_size=world_size
        )
        device_mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda",
            (2, world_size // 2),
        )

        # Create fake CUDA tensor using FakeTensorMode
        with fake_mode:
            x = torch.randn(1, 1, device="cuda")
            x = DTensor.from_local(x, device_mesh, [Shard(0), Shard(1)])

            # Test basic DTensor operations
            self.assertIsInstance(x, DTensor)

            # Test sum operation
            r = x.sum(1)
            self.assertIsInstance(r, DTensor)


if __name__ == "__main__":
    run_tests()

test/test_numa_binding.py

Lines changed: 3 additions & 2 deletions
@@ -7,7 +7,7 @@
 from dataclasses import dataclass
 from multiprocessing.context import SpawnProcess
 from typing import Any, Optional
-from unittest import skipUnless
+from unittest import skipIf, skipUnless
 from unittest.mock import mock_open, patch

 import torch
@@ -22,7 +22,7 @@
     AffinityMode,
     NumaOptions,
 )
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import IS_MACOS, run_tests, TestCase


 @dataclass(frozen=True)
@@ -680,6 +680,7 @@ def test_core_complex_tiebreak_prefers_lower_cache_key(self) -> None:
             set(range(0, 2)),
         )

+    @skipIf(IS_MACOS, "sched_getaffinity doesn't exist")
     def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
         self._add_mock_hardware(
             num_sockets=1,

torch/_C/_distributed_c10d.pyi

Lines changed: 9 additions & 0 deletions
@@ -851,3 +851,12 @@ class ProcessGroupXCCL(Backend):

 def _set_process_group(pg: ProcessGroup) -> None: ...
 def _current_process_group() -> ProcessGroup: ...
+def _dump_nccl_trace_json(
+    includeCollectives: Optional[bool] = ...,
+    onlyActive: Optional[bool] = ...,
+) -> bytes: ...
+def _dump_nccl_trace(
+    includeCollectives: Optional[bool] = ...,
+    includeStackTraces: Optional[bool] = ...,
+    onlyActive: Optional[bool] = ...,
+) -> bytes: ...

torch/distributed/_C_stubs.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# mypy: allow-untyped-defs
"""
Python stubs for backend-specific distributed components.

Since _C._distributed_c10d always exists now, this module only provides
stubs for backend-specific functionality that may not be available in all builds
(e.g., NCCL, UCC, MPI, Gloo, etc.).
"""

from __future__ import annotations

from typing import Optional, TYPE_CHECKING

from torch._C._distributed_c10d import Store


if TYPE_CHECKING:
    from datetime import timedelta

    import torch


# Store classes
class HashStore(Store):
    """Stub HashStore for builds without this functionality."""

    def __init__(self, *args, **kwargs):
        self._data = {}

    def set(self, key: str, value: str):
        self._data[key] = value

    def get(self, key: str) -> bytes:
        return self._data.get(key, "").encode()


# Backend-specific process group stubs
class ProcessGroupMPI:
    """Stub ProcessGroupMPI for non-MPI builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupNCCL:
    """Stub ProcessGroupNCCL for non-NCCL builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupGloo:
    """Stub ProcessGroupGloo for non-Gloo builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupUCC:
    """Stub ProcessGroupUCC for non-UCC builds."""

    def __init__(self, *args, **kwargs):
        pass


class ProcessGroupXCCL:
    """Stub ProcessGroupXCCL for non-XCCL builds."""

    def __init__(self, *args, **kwargs):
        pass


class _ProcessGroupWrapper:
    """Stub _ProcessGroupWrapper for non-Gloo builds."""

    def __init__(self, process_group, *args, **kwargs):
        self._process_group = process_group

    def __getattr__(self, name):
        return getattr(self._process_group, name)


# NCCL-specific function stubs
_DEFAULT_PG_NCCL_TIMEOUT: Optional[timedelta] = None


def _hash_tensors(tensors):
    """Stub function to hash tensors - returns dummy hash."""
    return 0


def _dump_nccl_trace_json(
    includeCollectives: Optional[bool] = None, onlyActive: Optional[bool] = None
) -> bytes:
    """Stub function that returns empty JSON trace."""
    return b"{}"


def _dump_nccl_trace(
    includeCollectives: Optional[bool] = None,
    includeStackTraces: Optional[bool] = None,
    onlyActive: Optional[bool] = None,
) -> bytes:
    """Stub function that returns empty pickle trace."""
    return b""


# NVSHMEM/SymmetricMemory stubs
def _is_nvshmem_available() -> bool:
    """Stub function that returns False indicating NVSHMEM is not available."""
    return False


def _nvshmemx_cumodule_init(module: int) -> None:
    """Stub function for NVSHMEM CU module initialization."""


class _SymmetricMemory:
    """Stub _SymmetricMemory class for builds without this functionality."""

    def __init__(self, *args, **kwargs):
        pass

    @classmethod
    def empty_strided_p2p(cls, size, stride, dtype, device, group_name=None):
        """Stub that returns a regular tensor."""
        return torch.empty(size, dtype=dtype, device=device)

    @classmethod
    def rendezvous(cls, tensor, group_name=None):
        """Stub that returns None."""
        return None

    @classmethod
    def set_group_info(cls, *args, **kwargs):
        """Stub that does nothing."""

    @classmethod
    def set_backend(cls, name):
        """Stub that does nothing."""

    @classmethod
    def get_backend(cls, device):
        """Stub that returns None."""
        return None

    @classmethod
    def has_multicast_support(cls, device_type, device_index):
        """Stub that returns False."""
        return False
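
A rough usage sketch, not part of the diff: on a build where these stubs are what gets imported, the functions are deliberately inert, so callers can probe capabilities and degrade gracefully instead of failing at import time.

# Hypothetical example calling the stub module directly on a build where the
# corresponding backends are not compiled in.
from torch.distributed import _C_stubs

print(_C_stubs._is_nvshmem_available())  # False: NVSHMEM not available
print(_C_stubs._dump_nccl_trace_json())  # b'{}': empty JSON trace
print(_C_stubs._hash_tensors([]))        # 0: dummy hash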
