Commit bb1424d

wconstab authored and pytorchmergebot committed
Reland #2 "[C10] PG observability hooks. (pytorch#108815, pytorch#110907)" (pytorch#111072)
This reverts commit 314a502.

Changes since the original PR:

Reland 1
* rename torch.distributed.hooks to torch.distributed._hooks

Reland 2
* make _hooks importable even if !distributed.is_available()
* handle an intermittent CUDA-driver-exit failure caused by new CUDA API usage in the callback caller (see the previous PR in this stack)

(original PR pytorch#108815 description copied below)

Expose a set of observability hooks into C10D so that users can detect collective failures both faster and more easily. The design is similar to NCCL desync debug in that it minimizes overhead by doing most of the work off the main thread.

This PR introduces a new module, torch.distributed.hooks, that exposes the following methods:

* register_collective_start_hook
* register_collective_end_hook
* register_process_group_hook

The process group hook reports PG creation on the member ranks and is called inline from the PG creation code. This is fine since it happens during initialization and only a limited number of times.

The collective start/end hooks are fired from a single background thread that reads events from a C++ queue and dispatches them to the registered callbacks. Queue notification is done, somewhat unusually, through a pipe: this lets Python abort the thread on shutdown and keep it as a background thread, which is not possible with more conventional choices such as a condition variable.

Pull Request resolved: pytorch#111072
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#111061
1 parent dede1e9 commit bb1424d
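
Before the file-by-file diff, a minimal usage sketch of the new hooks API may help orient readers. It is not part of the commit; it is assembled from the API exercised by the new test/distributed/test_hooks.py shown below, and the status fields it prints (pg_name, backend, operation, sequence_number, drop_count) are simply the attributes those tests assert on.

import torch.distributed._hooks as dhooks


def on_pg_created(pg, pg_name):
    # Process-group hook: called inline from PG creation code on member ranks.
    print(f"pg created: {pg_name}")


def on_collective_start(status: dhooks.CollectiveStatus):
    # Collective hooks run on the single background dispatch thread.
    print(f"start {status.operation} pg={status.pg_name} seq={status.sequence_number}")


def on_collective_end(status: dhooks.CollectiveStatus):
    print(f"end {status.operation} backend={status.backend} dropped={status.drop_count}")


dhooks.register_process_group_hook(on_pg_created)
dhooks.register_collective_start_hook(on_collective_start)
dhooks.register_collective_end_hook(on_collective_end)

# With a process group initialized (e.g. dist.init_process_group("gloo", ...)),
# collectives such as dist.all_reduce(torch.ones(2, 3)) will fire the hooks above.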

File tree

17 files changed (+702, -29 lines)


build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -522,6 +522,7 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/Backend.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
+    "torch/csrc/distributed/c10d/Hooks.cpp",
     "torch/csrc/distributed/c10d/Ops.cpp",
     "torch/csrc/distributed/c10d/ParamCommsUtils.cpp",
     "torch/csrc/distributed/c10d/PrefixStore.cpp",

test/distributed/test_hooks.py

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile
import threading
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._hooks as dhooks

if not dist.is_available():
    print("torch.distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)

from torch.testing._internal.common_utils import run_tests, TestCase


class PgHooks(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 4

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_pg_hook(self):
        pgs = []

        def pg_hook(pg, pg_name):
            pgs.append((pg, pg_name))

        dhooks.register_process_group_hook(pg_hook)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        self.assertEqual(len(pgs), 1)
        self.assertEqual(pgs[0][0], dist.group.WORLD)

        # create two partial-world PGs
        pg0 = dist.new_group(ranks=[0, 1])
        pg1 = dist.new_group(ranks=[2, 3])

        # Each rank only observes two PGs being created: the default PG and the one covering its ranks.
        # We don't emit events for PG creation if the current rank doesn't belong to it.
        # For example, if you're rank 1, you'll get an event for pg0 but not pg1, even though the API contract
        # dictates you need to call new_group for both.
        self.assertEqual(len(pgs), 2)
        self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1)


def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.init_comms()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class CollectiveHooks:
    @property
    def world_size(self) -> int:
        return 4

    def _collective_hooks(self):
        # it's ok to access these lists directly since a single background thread pokes at them.
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"col_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"col_end {len(ends)} rank{self.rank}")
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)

        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)

        with cv:
            cv.wait(1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)

            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def test_queue_overflow(self) -> None:
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)

        # native limit is 512
        for i in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        colls_done = True
        with cv_done_colls:
            cv_done_colls.notify()

        with cv_done_cb:
            cv_done_cb.wait(10)

        self.assertTrue(status_with_dropped is not None)
        self.assertTrue(status_with_dropped.drop_count > 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()

torch/_C/_distributed_c10d.pyi

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,10 @@ _DEFAULT_FIRST_BUCKET_BYTES: int
 _DEFAULT_NO_TIMEOUT: timedelta
 _DEFAULT_PG_TIMEOUT: timedelta
 
+class EventKind(Enum):
+    START = ...
+    END = ...
+
 class BuiltinCommHookType(Enum):
     ALLREDUCE = ...
     FP16_COMPRESS = ...
@@ -20,6 +24,8 @@ def _register_builtin_comm_hook(
     reducer: Reducer,
     comm_hook_type: BuiltinCommHookType,
 ): ...
+def _dequeue_c10d_event() -> Dict[str, object]: ...
+def _enable_event_collection(pipe_fs: int) -> None: ...
 
 class GradBucket:
     def index(self) -> int: ...
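
The two private bindings added above are what the Python side of this PR consumes. The following is a rough sketch, not the actual torch/distributed/_hooks.py implementation (which is part of this PR but not shown in this excerpt), of how they could be wired together under the pipe-notification scheme the commit message describes: the C++ queue signals readiness through a pipe file descriptor, and a daemon thread drains events with _dequeue_c10d_event(). The one-byte-per-event signaling and the _start_event_consumer helper are assumptions for illustration.

import os
import threading

from torch._C._distributed_c10d import _dequeue_c10d_event, _enable_event_collection


def _start_event_consumer(callback):
    # Hypothetical helper: hand the write end of a pipe to C++ so it can signal new events.
    read_fd, write_fd = os.pipe()
    _enable_event_collection(write_fd)

    def _loop():
        while True:
            os.read(read_fd, 1)  # block until the C++ side signals an enqueued event
            callback(_dequeue_c10d_event())  # dict with pg_name, backend, operation, ...

    # Daemon thread so interpreter shutdown is not blocked waiting on the pipe.
    threading.Thread(target=_loop, name="c10d_event_consumer", daemon=True).start()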

torch/csrc/distributed/c10d/Backend.cpp

Lines changed: 34 additions & 0 deletions
@@ -1,9 +1,26 @@
 #include <c10/util/Logging.h>
 #include <fmt/format.h>
 #include <torch/csrc/distributed/c10d/Backend.hpp>
+#include <torch/csrc/distributed/c10d/Hooks.hpp>
+#include <torch/csrc/distributed/c10d/logging.h>
 
 namespace c10d {
 
+namespace {
+void commonEventinit(
+    details::EventInfo& evt,
+    const Backend& backend,
+    const Work& work) {
+  evt.timestamp =
+      std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
+  evt.pg_name = backend.getGroupName();
+  evt.backend = backend.getBackendName();
+  evt.sequence_number = work.getSequencenumber();
+  evt.operation = c10d::opTypeToString(work.retrieveOpType());
+  evt.drop_count = 0;
+}
+} // namespace
+
 Backend::Backend(int rank, int size)
     : rank_(rank), size_(size), dist_debug_level_(debug_level()) {
   C10_LOG_API_USAGE_ONCE("c10d.backend");
@@ -15,4 +32,21 @@ void Backend::init() {
   C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
 }
 
+void Backend::emitCollectiveStart(const Work& work) {
+  details::EventInfo evt;
+  commonEventinit(evt, *this, work);
+
+  evt.event_kind = ::c10d::EventKind::CollectiveStart;
+  details::enqueue_c10d_event(std::move(evt));
+}
+
+void Backend::emitCollectiveEnd(const Work& work) {
+  details::EventInfo evt;
+  commonEventinit(evt, *this, work);
+
+  evt.event_kind = ::c10d::EventKind::CollectiveEnd;
+  evt.duration_ms = work.getDuration();
+  details::enqueue_c10d_event(std::move(evt));
+}
+
 } // namespace c10d

torch/csrc/distributed/c10d/Backend.hpp

Lines changed: 2 additions & 0 deletions
@@ -366,6 +366,8 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
   void init();
+  void emitCollectiveStart(const Work& work);
+  void emitCollectiveEnd(const Work& work);
 
   const int rank_;
   const int size_;
