Commit 06cfde3

Kiuk Chung authored and facebook-github-bot committed
Delete classy elastic trainer and dependency to pet checkpoint api (pytorch#80)
Summary:
Pull Request resolved: pytorch#80
Pull Request resolved: facebookresearch/ClassyVision#464

Obsolete code as of D20787422.

Reviewed By: vreis

Differential Revision: D20787751

fbshipit-source-id: d0bba49902467e9f117ec6ac8199a7fc05b91ab4
1 parent fd92ee6 commit 06cfde3

2 files changed (+1, -62 lines)

test/p2p/elastic_trainer_test_base.py

Lines changed: 0 additions & 54 deletions
@@ -624,60 +624,6 @@ def test_sync_retryable_exception(self):
         # are retryable / non-fatal
         self.assertEqual([410, 410, 410, 410], sums)
 
-    def test_checkpoint(self):
-        """
-        Test with 4 trainers:
-        - Save checkpoint every train_step
-        - Trainers suicide at 3rd step
-        - Restart training (from checkpoint)
-        """
-
-        def process_crash():
-            log.warning("Suicide, pid:{}".format(os.getpid()))
-            os.kill(os.getpid(), signal.SIGKILL)
-
-        hooks = {"process_crash": process_crash}
-        run_id = self._generate_run_id()
-
-        nprocs = 4
-
-        # Before training, there is no checkpoint
-        checkpoint_manager = FileSystemCheckpointManager(self.test_dir.name)
-        self.assertEqual(0, len(checkpoint_manager.list_checkpoints()))
-
-        for _ in range(0, nprocs):
-            _, qout, qerr = self._spawn(
-                self._train_with_checkpoint, run_id, _train_step, hooks
-            )
-
-        # wait all training process complete
-        # clean up for next run
-        self._wait_all_and_clean()
-
-        # we run 2 steps before suicide, expect two checkpoints be saved
-        self.assertEqual(2, len(checkpoint_manager.list_checkpoints()))
-
-        qouts = []
-        qerrs = []
-        # start next run
-        for _ in range(0, nprocs):
-            _, qout, qerr = self._spawn(
-                self._train_with_checkpoint, run_id, _train_step, None
-            )
-            qouts.append(qout)
-            qerrs.append(qerr)
-
-        # Gather all nums and sums from final states, they should match the input
-        sums = []
-        for i in range(0, nprocs):
-            state = _get_or_raise(qouts[i], qerrs[i])
-            # Everyone reads 3 samples after recovering from checkpoint:
-            self.assertEqual(3, len(state.nums))
-            sums.append(state.total_sum)
-
-        # The job should be completely recovered through checkpoints / crashes:
-        self.assertEqual([410, 410, 410, 410], sums)
-
     def test_process_crash(self):
         """
         Test 4 trainers, 2 of which SIGKILL themselves and terminate.
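
As an aside, the crash hook the deleted test installed is a self-contained way to simulate a hard worker failure: SIGKILL cannot be caught, so no Python-level cleanup or except blocks run. A minimal sketch of that hook, lifted from the deleted lines above and using only the standard library (the logger setup is an assumption, not part of the original test file):

import logging
import os
import signal

log = logging.getLogger(__name__)


def process_crash():
    # Hard-kill the current process to mimic an unrecoverable worker failure.
    log.warning("Suicide, pid:{}".format(os.getpid()))
    os.kill(os.getpid(), signal.SIGKILL)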

torchelastic/train_loop.py

Lines changed: 1 addition & 8 deletions
@@ -11,7 +11,6 @@
 import warnings
 
 import torchelastic
-from torchelastic.checkpoint import CheckpointUtil
 from torchelastic.coordinator import NonRetryableException, StopException
 from torchelastic.metrics import get_elapsed_time_ms, publish_metric
 
@@ -69,8 +68,6 @@ def run_train(coordinator, train_step_gen, state):
     failure_count = 0
     rank = 0
 
-    checkpoint_util = CheckpointUtil(coordinator)
-
     while not coordinator.should_stop_training():
         # See: https://github.com/pytorch/elastic/issues/7
         if failure_count >= MAX_FAILURES:
@@ -90,17 +87,14 @@ def run_train(coordinator, train_step_gen, state):
             # does not sync.
             coordinator.barrier()
 
-            # load checkpoint if necessary
-            state = checkpoint_util.load_checkpoint(state, rank)
-
             state_sync_start_time = time.time()
             state.sync(world_size, rank)
            publish_metric(
                 "torchelastic",
                 "state_sync.duration.ms",
                 get_elapsed_time_ms(state_sync_start_time),
             )
-            checkpoint_util.set_checkpoint_loaded()
+
             coordinator.barrier()
             log.info("Rank {0} synced state with other nodes".format(rank))
         except StopException:
@@ -140,7 +134,6 @@ def run_train(coordinator, train_step_gen, state):
 
         coordinator.monitor_progress(state, worker_stats)
 
-        checkpoint_util.save_checkpoint(state, rank)
         if coordinator.should_rendezvous(state):
             log.info("Rank {0} will re-rendezvous".format(rank))
             # Executor told us, for whatever reason, to re-rendezvous.
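
Net effect on the train loop: with CheckpointUtil gone, a worker that (re)joins the group catches up only through state.sync(), and nothing in the loop shown here persists state between restarts anymore. The sketch below repackages just the context lines visible in this diff into a hypothetical helper for illustration; sync_state is not a real torchelastic function, and the coordinator/state objects are assumed to behave as the diff implies:

import logging
import time

from torchelastic.metrics import get_elapsed_time_ms, publish_metric

log = logging.getLogger(__name__)


def sync_state(coordinator, state, world_size, rank):
    # Post-commit sync path: no checkpoint load before, and no
    # set_checkpoint_loaded() after, the state synchronization.
    coordinator.barrier()

    state_sync_start_time = time.time()
    state.sync(world_size, rank)
    publish_metric(
        "torchelastic",
        "state_sync.duration.ms",
        get_elapsed_time_ms(state_sync_start_time),
    )

    coordinator.barrier()
    log.info("Rank {0} synced state with other nodes".format(rank))
    return state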
