# Owner(s): ["oncall: distributed"]
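# Exercises the torch.distributed._hooks introspection API: hooks that fire when
# process groups are created and when collectives start and finish.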

import os
import sys
import tempfile
import threading
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)

from torch.testing._internal.common_utils import run_tests, TestCase


class PgHooks(MultiProcessTestCase):
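    """Checks that a registered process-group hook fires once for every process
    group the current rank belongs to."""
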
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_pg_hook(self):
        pgs = []

        def pg_hook(pg, pg_name):
            pgs.append((pg, pg_name))

        dhooks.register_process_group_hook(pg_hook)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        self.assertEqual(len(pgs), 1)
        self.assertEqual(pgs[0][0], dist.group.WORLD)

        # create two partial world PGs
        pg0 = dist.new_group(ranks=[0, 1])
        pg1 = dist.new_group(ranks=[2, 3])

        # Each rank only observes two PGs being created: the default PG and the one
        # covering its own ranks. No event is emitted for a PG the current rank does
        # not belong to. For example, rank 1 gets an event for pg0 but not for pg1,
        # even though the API contract requires it to call new_group for both.
        self.assertEqual(len(pgs), 2)
        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)


def with_comms(func=None):
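    """Decorator that brings up the process group before the test body and tears
    it down afterwards via the test class's init_comms/destroy_comms helpers."""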
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.init_comms()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class CollectiveHooks:
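    """Mixin with the shared collective start/end hook assertions. Concrete
    subclasses provide init_comms/destroy_comms and the backend_name/device
    properties."""
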
    @property
    def world_size(self) -> int:
        return 4

    def _collective_hooks(self):
        # It's fine to append to these lists without locking: the hooks are
        # invoked from a single background thread.
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"col_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"col_end {len(ends)} rank{self.rank}")
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)

        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)

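        # coll_end notifies the condition variable once both collectives have been
        # reported; wait (up to 1s) for the background delivery thread before asserting.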
        with cv:
            cv.wait(1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)

            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
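    """Single-process (world_size=1) tests for the hook event queue itself."""
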
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def test_queue_overflow(self) -> None:
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

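        # Block inside the start hook until all collectives below have been issued.
        # The hook event queue is bounded, so later events are dropped and a status
        # delivered after the backlog should report a nonzero drop_count.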
        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)

        # The native event queue holds 512 entries, so issuing 600 collectives is
        # enough to overflow it and force drops.
        for _ in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        colls_done = True
        with cv_done_colls:
            cv_done_colls.notify()

        with cv_done_cb:
            cv_done_cb.wait(10)

        self.assertIsNotNone(status_with_dropped)
        self.assertGreater(status_with_dropped.drop_count, 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()