.. _elastic_train_script:

Train script
-------------

If your train script works with ``torch.distributed.launch`` it will continue
working with ``torch.distributed.run`` with these differences:

1. No need to manually pass ``RANK``, ``WORLD_SIZE``,
   ``MASTER_ADDR``, and ``MASTER_PORT``.

2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
   this will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_). The default
   ``rdzv_backend`` creates a non-elastic rendezvous where ``rdzv_endpoint`` holds
   the master address. An example launch command is shown in the sketch after
   this list.

3. Make sure you have ``load_checkpoint(path)`` and
   ``save_checkpoint(path)`` logic in your script. When any number of
   workers fail, we restart all the workers with the same program
   arguments, so you will lose progress up to the most recent checkpoint
   (see `elastic launch <distributed.html>`_).

4. The ``use_env`` flag has been removed. If you were reading the local rank by
   parsing the ``--local_rank`` option, you now need to get it from the
   ``LOCAL_RANK`` environment variable (e.g. ``int(os.environ["LOCAL_RANK"])``);
   a sketch follows this list.
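
As an illustration of points 2 and 4, below is a minimal sketch of how a worker
might consume what the launcher provides. The launch command, the backend choice,
and the device-selection logic here are illustrative, not requirements of
``torch.distributed.run``.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    # Launched with, for example:
    #   python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 \
    #       --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py

    def setup_distributed():
        # torch.distributed.run exports these variables for every worker it spawns.
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        # Illustrative device selection: pin this worker to its local GPU, if any.
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)

        # MASTER_ADDR and MASTER_PORT are exported as well, so no explicit
        # init_method or store has to be passed here.
        dist.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo"
        )
        return local_rank, rank, world_size

The same worker code runs unmodified whether the job is launched on one node or
many, since the rendezvous settings live entirely on the command line.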

Below is an expository example of a training script that checkpoints on each
epoch, hence the worst-case progress lost on failure is one full epoch worth
of training.

.. code-block:: python

    def main():
        args = parse_args(sys.argv[1:])
        state = load_checkpoint(args.checkpoint_path)
        initialize(state)

        # torch.distributed.run ensures that this will work
        # by exporting all the env vars needed to initialize the process group
        torch.distributed.init_process_group(backend=args.backend)
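
The snippet above stops after process-group initialization. Below is a
self-contained sketch of the same checkpoint-per-epoch pattern; the checkpoint
path, the toy model, and the helper names are illustrative and not part of the
example above.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    # Illustrative path; use shared storage for multi-node jobs.
    CHECKPOINT_PATH = "/tmp/checkpoint.pt"

    def save_checkpoint(path, model, optimizer, epoch):
        # Only one rank writes the checkpoint; every rank loads it after a restart.
        if dist.get_rank() == 0:
            torch.save(
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                },
                path,
            )

    def load_checkpoint(path, model, optimizer):
        # Resume from the latest checkpoint if one exists, else start from epoch 0.
        if os.path.exists(path):
            snapshot = torch.load(path, map_location="cpu")
            model.load_state_dict(snapshot["model"])
            optimizer.load_state_dict(snapshot["optimizer"])
            return snapshot["epoch"]
        return 0

    def main():
        dist.init_process_group(backend="gloo")
        model = nn.Linear(10, 10)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        start_epoch = load_checkpoint(CHECKPOINT_PATH, model, optimizer)

        total_epochs = 10
        for epoch in range(start_epoch, total_epochs):
            # ... one epoch of training goes here ...

            # Checkpoint after every epoch so a restart loses at most one epoch.
            save_checkpoint(CHECKPOINT_PATH, model, optimizer, epoch + 1)
            dist.barrier()  # keep workers roughly in step between epochs

    if __name__ == "__main__":
        main()

Because every worker is restarted with the same program arguments after a
failure, the only state that survives is what ``save_checkpoint`` wrote, which
is why the checkpoint is written every epoch rather than only at the end of
training.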