Commit ccfdb30

VitalyFedyunin authored and facebook-github-bot committed
Revert D29413019: [torch] Various improvements to torch.distributed.launch and torch.distributed.run
Test Plan: revert-hammer
Differential Revision: D29413019 (pytorch@4e181df)
Original commit changeset: 323bfbad9d0e
fbshipit-source-id: 1f8ae4b3d0a23f3eaff28c37e9148efff25fafe2
1 parent 48bfc0e commit ccfdb30

File tree

12 files changed: +57 -148 lines changed


docs/source/elastic/errors.rst

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-.. _elastic_errors-api:
-
 Error Propagation
 ==================

docs/source/elastic/run.rst

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,9 @@
 .. _launcher-api:

-torch.distributed.run (Elastic Launch)
-======================================
+Elastic Launch
+============================
+
+torch.distributed.run
+----------------------

 .. automodule:: torch.distributed.run

docs/source/elastic/train_script.rst

Lines changed: 7 additions & 11 deletions
@@ -1,5 +1,3 @@
-.. _elastic_train_script:
-
 Train script
 -------------

@@ -9,20 +7,18 @@ working with ``torch.distributed.run`` with these differences:
 1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
    ``MASTER_ADDR``, and ``MASTER_PORT``.

-2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
-   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The default
-   ``rdzv_backend`` creates a non-elastic rendezvous where ``rdzv_endpoint`` holds
-   the master address.
+2. ``rdzv_backend`` and ``rdzv_endpoint`` must be provided. For most users
+   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_).

 3. Make sure you have a ``load_checkpoint(path)`` and
-   ``save_checkpoint(path)`` logic in your script. When any number of
-   workers fail we restart all the workers with the same program
-   arguments so you will lose progress up to the most recent checkpoint
+   ``save_checkpoint(path)`` logic in your script. When workers fail
+   we restart all the workers with the same program arguments so you will
+   lose progress up to the most recent checkpoint
    (see `elastic launch <distributed.html>`_).

 4. ``use_env`` flag has been removed. If you were parsing local rank by parsing
    the ``--local_rank`` option, you need to get the local rank from the
-   environment variable ``LOCAL_RANK`` (e.g. ``int(os.environ["LOCAL_RANK"])``).
+   environment variable ``LOCAL_RANK`` (e.g. ``os.environ["LOCAL_RANK"]``).

 Below is an expository example of a training script that checkpoints on each
 epoch, hence the worst-case progress lost on failure is one full epoch worth
@@ -35,7 +31,7 @@ of training.
   state = load_checkpoint(args.checkpoint_path)
   initialize(state)

-  # torch.distributed.run ensures that this will work
+  # torch.distributed.run ensure that this will work
   # by exporting all the env vars needed to initialize the process group
   torch.distributed.init_process_group(backend=args.backend)
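The convention this doc change describes (read LOCAL_RANK from the environment the launcher exports, then let init_process_group pick up the remaining rendezvous variables) looks roughly like the sketch below. This is a minimal illustration, not code from the diff; the backend choice and the checkpoint comments are placeholders.

import os

import torch
import torch.distributed as dist

def main():
    # torch.distributed.run exports LOCAL_RANK, RANK, WORLD_SIZE,
    # MASTER_ADDR and MASTER_PORT for every worker it spawns.
    local_rank = int(os.environ["LOCAL_RANK"])

    # The exported env vars let the default env:// init method find peers.
    dist.init_process_group(backend="gloo")

    # ... load_checkpoint(path) / training loop / save_checkpoint(path) ...
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}, local rank {local_rank}")

if __name__ == "__main__":
    main()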

torch/distributed/elastic/agent/server/local_elastic_agent.py

Lines changed: 1 addition & 0 deletions
@@ -205,6 +205,7 @@ def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
         result = self._pcontext.wait(0)
         if result:
             if result.is_failed():
+                log.error(f"[{role}] Worker group failed")
                 # map local rank failure to global rank
                 worker_failures = {}
                 for local_rank, failure in result.failures.items():

torch/distributed/elastic/events/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -19,7 +19,6 @@

 """

-import os
 import logging

 from torch.distributed.elastic.events.handlers import get_logging_handler
@@ -47,12 +46,12 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger:
         return _events_logger
     logging_handler = get_logging_handler(destination)
     _events_logger = logging.getLogger(f"torchelastic-events-{destination}")
-    _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
+    _events_logger.setLevel(logging.DEBUG)
     # Do not propagate message to the root logger
     _events_logger.propagate = False
     _events_logger.addHandler(logging_handler)
     return _events_logger


-def record(event: Event, destination: str = "null") -> None:
+def record(event: Event, destination: str = "console") -> None:
     _get_or_create_logger(destination).info(event.serialize())
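For context on the functions touched here: record serializes an event and writes it through the handler registered for the destination, which after this revert defaults to "console". A minimal sketch follows; the Event constructor arguments are an assumption about its dataclass fields, not something shown in this diff.

from torch.distributed.elastic.events import Event, EventSource, record

# With "console" as the default destination, the serialized event goes to a
# logging.StreamHandler instead of being dropped by a NullHandler.
event = Event(name="epoch_done", source=EventSource.WORKER, metadata={"epoch": 3})
record(event)                         # uses the default "console" destination
record(event, destination="console")  # equivalent, destination made explicit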

torch/distributed/elastic/events/handlers.py

Lines changed: 1 addition & 2 deletions
@@ -12,9 +12,8 @@

 _log_handlers: Dict[str, logging.Handler] = {
     "console": logging.StreamHandler(),
-    "null": logging.NullHandler(),
 }


-def get_logging_handler(destination: str = "null") -> logging.Handler:
+def get_logging_handler(destination: str = "console") -> logging.Handler:
     return _log_handlers[destination]
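A short usage note grounded in the dictionary above: after the revert, "console" is both the default and the only registered destination, so any other key raises KeyError.

from torch.distributed.elastic.events.handlers import get_logging_handler

handler = get_logging_handler()                # the "console" StreamHandler
same_handler = get_logging_handler("console")  # same entry, requested explicitly
# get_logging_handler("null") now raises KeyError: the "null" entry was removed.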

torch/distributed/elastic/multiprocessing/api.py

Lines changed: 1 addition & 12 deletions
@@ -497,17 +497,6 @@ def close(self):
             self._stderr.close()


-def _pr_set_pdeathsig() -> None:
-    """
-    Sets PR_SET_PDEATHSIG to ensure a child process is
-    terminated appropriately.
-
-    See http://stackoverflow.com/questions/1884941/ for more information.
-    For libc.so.6 read http://www.linux-m68k.org/faq/glibcinfo.html
-    """
-    mp._prctl_pr_set_pdeathsig(signal.SIGTERM)  # type: ignore[attr-defined]
-
-
 class SubprocessContext(PContext):
     """
     ``PContext`` holding worker processes invoked as a binary.
@@ -552,7 +541,7 @@ def _start(self):
                 entrypoint=self.entrypoint,  # type: ignore[arg-type] # entrypoint is always a str
                 args=self.args[local_rank],
                 env=self.envs[local_rank],
-                preexec_fn=_pr_set_pdeathsig,
+                preexec_fn=mp._prctl_pr_set_pdeathsig(signal.SIGTERM),  # type: ignore[attr-defined]
                 stdout=self.stdouts[local_rank],
                 stderr=self.stderrs[local_rank],
             )
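Background on the mechanism the removed helper wrapped: preexec_fn runs in the child between fork and exec, and PR_SET_PDEATHSIG asks the Linux kernel to signal the child if its parent dies. A minimal, Linux-only sketch of that pattern (the helper name and the direct libc call are illustrative, not PyTorch code):

import ctypes
import signal
import subprocess

PR_SET_PDEATHSIG = 1  # constant from <linux/prctl.h>

def _set_parent_death_signal():
    # Runs in the child after fork(), before exec(): ask the kernel to
    # deliver SIGTERM to this process when its parent terminates.
    libc = ctypes.CDLL("libc.so.6", use_errno=True)
    libc.prctl(PR_SET_PDEATHSIG, int(signal.SIGTERM))

# preexec_fn expects a no-argument callable that runs in the child process.
proc = subprocess.Popen(["sleep", "60"], preexec_fn=_set_parent_death_signal)
proc.terminate()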

torch/distributed/elastic/utils/logging.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@ def get_logger(name: Optional[str] = None):
     """
     Util function to set up a simple logger that writes
     into stderr. The loglevel is fetched from the LOGLEVEL
-    env. variable or WARNING as default. The function will use the
+    env. variable or INFO as default. The function will use the
     module name of the caller if no name is provided.

     Args:
@@ -32,7 +32,7 @@ def get_logger(name: Optional[str] = None):

 def _setup_logger(name: Optional[str] = None):
     log = logging.getLogger(name)
-    log.setLevel(os.environ.get("LOGLEVEL", "WARNING"))
+    log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
     return log
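The LOGLEVEL convention documented above can be exercised directly; a small sketch (the logger name is a placeholder):

import logging
import os

# Same pattern as _setup_logger: the LOGLEVEL env var wins,
# otherwise fall back to the reverted default of INFO.
log = logging.getLogger("trainer")
log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
log.addHandler(logging.StreamHandler())

log.info("visible at the default level")
log.debug("hidden unless LOGLEVEL=DEBUG is exported")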

torch/distributed/elastic/utils/store.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import warnings
 from datetime import timedelta
 from typing import List

@@ -63,5 +64,8 @@ def barrier(
     Note: Since the data is not removed from the store, the barrier can be used
     once per unique ``key_prefix``.
     """
+    warnings.warn(
+        "This is an experimental API and will be changed in future.", FutureWarning
+    )
     data = f"{rank}".encode(encoding="UTF-8")
     synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)
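Callers who want quiet logs while the API stays experimental can filter the new warning around the call site. The sketch below reproduces the warning pattern added above in a stand-in function rather than calling barrier itself, since the full signature is not shown in this hunk.

import warnings

def experimental_call():
    # Same pattern the diff adds to barrier(): emit a FutureWarning on entry.
    warnings.warn(
        "This is an experimental API and will be changed in future.", FutureWarning
    )

# Suppress the warning locally without changing global warning filters.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    experimental_call()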

torch/distributed/launch.py

Lines changed: 7 additions & 25 deletions
@@ -1,10 +1,8 @@
 r"""
-``torch.distributed.launch`` is a module that spawns up multiple distributed
+`torch.distributed.launch` is a module that spawns up multiple distributed
 training processes on each of the training nodes.

-.. warning::
-
-    This module is going to be deprecated in favor of :ref:`torch.distributed.run <launcher-api>`.
+NOTE: This module is deprecated, use torch.distributed.run.

 The utility can be used for single-node distributed training, in which one or
 more processes per node will be spawned. The utility can be used for either
@@ -138,12 +136,9 @@
 https://github.com/pytorch/pytorch/issues/12042 for an example of
 how things can go wrong if you don't do this correctly.

-
-
 """

 import logging
-import warnings

 from torch.distributed.run import get_args_parser, run

@@ -164,27 +159,14 @@ def parse_args(args):
     return parser.parse_args(args)


-def launch(args):
-    if args.no_python and not args.use_env:
-        raise ValueError(
-            "When using the '--no_python' flag,"
-            " you must also set the '--use_env' flag."
-        )
-    run(args)
-
-
 def main(args=None):
-    warnings.warn(
-        "The module torch.distributed.launch is deprecated\n"
-        "and will be removed in future. Use torch.distributed.run.\n"
-        "Note that --use_env is set by default in torch.distributed.run.\n"
-        "If your script expects `--local_rank` argument to be set, please\n"
-        "change it to read from `os.environ('LOCAL_RANK')` instead. See \n"
-        "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
-        "further instructions\n", FutureWarning
+    logger.warning(
+        "The module torch.distributed.launch is deprecated "
+        "and going to be removed in future."
+        "Migrate to torch.distributed.run"
     )
     args = parse_args(args)
-    launch(args)
+    run(args)


 if __name__ == "__main__":
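A minimal sketch of what the reverted main() now does, using only the imports visible in this file; the argument list is illustrative and train.py is a placeholder for a real training script.

from torch.distributed.run import get_args_parser, run

# Roughly equivalent to: python -m torch.distributed.run --nproc_per_node=2 train.py
args = get_args_parser().parse_args(["--nproc_per_node", "2", "train.py"])
run(args)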
