
Commit eb13ba2

Provide load_seed_checkpoint_only option (#1800)
Seed checkpoint loading requires `checkpoint.enable` to be True. However, when using seed checkpoints, users typically don't want to save subsequent checkpoints, since seed checkpoints serve verification purposes only.
1 parent 96149f6 commit eb13ba2

File tree

4 files changed: 81 additions, 1 deletion


docs/debugging.md

Lines changed: 13 additions & 0 deletions
@@ -98,11 +98,24 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
 
 For multiple experimental runs with different parallelism configs, we need to use a "seed" checkpoint to ensure model initializations are the same across runs. This is because in `torchtitan/train.py`, the model parameters are sharded first, and then have their weights initialized on each rank separately. As a result, it is not equivalent to initialize the model on one rank and then shard it. Using a seed checkpoint helps different runs load the same model weights from checkpoint -- DCP resharding will make sure the loaded weights are sharded correctly according to the parallelism configs.
 
+#### Creating a Seed Checkpoint
 
 ```bash
 NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```
 
+#### Loading Seed Checkpoints for Debugging
+
+When using seed checkpoints for debugging or validation purposes, you can enable the `load_only` configuration to load checkpoints without saving any new ones during training. This is particularly useful when you only want to verify model correctness or compare different configurations without cluttering your disk:
+
+```bash
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.load_only
+```
+
+The `--checkpoint.load_only` flag prevents the training process from saving any checkpoints, allowing you to:
+- Run debugging sessions without generating unwanted checkpoint files
+- Compare model behaviors using the same initial weights without checkpoint overhead
+
 **Note**: Using a seed checkpoint will only make sure a model has same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc.
 
 ### Example: Reproducing loss curves with different parallelism configs

tests/unit_tests/test_checkpoint.py

Lines changed: 58 additions & 0 deletions
@@ -675,6 +675,64 @@ def __init__(self):
 
         manager.close()
 
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_load_only_prevents_saving(self, mock_save, mock_rank):
+        """
+        Test that load_only=True prevents checkpoint saving.
+        """
+        mock_save.side_effect = self.fake_save
+
+        # Configure load_only=True
+        cfg = self.job_config.checkpoint
+        cfg.load_only = True
+        cfg.interval = 1  # Set low interval to ensure saves would normally trigger
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            checkpoint_config=self.job_config.checkpoint,
+            sd_adapter=None,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        # Test various save conditions that would normally trigger saves
+        manager.save(curr_step=1)  # Regular step save
+        self.assertEqual(mock_save.call_count, 0)
+
+        manager.save(curr_step=5)  # Interval-based save
+        self.assertEqual(mock_save.call_count, 0)
+
+        manager.save(curr_step=10, last_step=True)  # Last step save
+        self.assertEqual(mock_save.call_count, 0)
+
+        manager.close()
+
+        # Verify that saves work normally when load_only=False
+        mock_save.reset_mock()
+        cfg.load_only = False
+
+        manager2 = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            checkpoint_config=self.job_config.checkpoint,
+            sd_adapter=None,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        manager2.save(curr_step=1)  # Should trigger save now
+        self.assertEqual(mock_save.call_count, 1)
+
+        manager2.close()
+
     @mock.patch("torch.distributed.get_rank", return_value=0)
     @mock.patch("torchtitan.components.checkpoint.dcp.load")
     @mock.patch("torchtitan.components.checkpoint.dcp.save")

torchtitan/components/checkpoint.py

Lines changed: 2 additions & 1 deletion
@@ -190,6 +190,7 @@ def __init__(
         ft_manager: FTManager | None = None,
     ) -> None:
         self.enable = checkpoint_config.enable
+        self.load_only = checkpoint_config.load_only
 
         self.ft_manager = (
             ft_manager.manager if ft_manager and ft_manager.enabled else None
 
@@ -761,7 +762,7 @@ def _save_last_step(self, curr_step: int) -> None:
         )
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
-        if not self.enable:
+        if not self.enable or self.load_only:
             return False
 
         if curr_step == 1 and self.enable_first_step_checkpoint:

torchtitan/config/job_config.py

Lines changed: 8 additions & 0 deletions
@@ -527,6 +527,14 @@ class Checkpoint:
     Could be implemented as a separate script, but this way shares more code.
     """
 
+    load_only: bool = False
+    """
+    In certain scenarios, you may only need to load checkpoints for verification or debugging
+    purposes, without saving any new checkpoints. For example, you might use seed checkpoints
+    to validate model correctness. Enabling this option allows checkpoints to be loaded
+    without saving any during the training.
+    """
+
 
 @dataclass
 class ActivationCheckpoint:
