diff --git a/docs/debugging.md b/docs/debugging.md
index bc683bd9b..f7758cbde 100644
--- a/docs/debugging.md
+++ b/docs/debugging.md
@@ -98,11 +98,24 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
 
 For multiple experimental runs with different parallelism configs, we need to use a "seed" checkpoint to ensure model initializations are the same across runs. This is because in `torchtitan/train.py`, the model parameters are sharded first, and then have their weights initialized on each rank separately. As a result, it is not equivalent to initialize the model on one rank and then shard it. Using a seed checkpoint helps different runs load the same model weights from checkpoint -- DCP resharding will make sure the loaded weights are sharded correctly according to the parallelism configs.
 
+#### Creating a Seed Checkpoint
 ```bash
 NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```
 
+#### Loading Seed Checkpoints for Debugging
+
+When using seed checkpoints for debugging or validation purposes, you can enable the `load_only` option to load a checkpoint without saving any new ones during training. This is particularly useful when you only want to verify model correctness or compare different configurations without cluttering your disk:
+
+```bash
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.load_only
+```
+
+The `--checkpoint.load_only` flag prevents the training process from saving any checkpoints, allowing you to:
+- Run debugging sessions without generating unwanted checkpoint files
+- Compare model behaviors using the same initial weights without checkpoint overhead
+
 
 **Note**: Using a seed checkpoint will only make sure a model has same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc.
 
 ### Example: Reproducing loss curves with different parallelism configs
diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py
index de02741b6..39c7e83d6 100644
--- a/tests/unit_tests/test_checkpoint.py
+++ b/tests/unit_tests/test_checkpoint.py
@@ -675,6 +675,64 @@ def __init__(self):
 
         manager.close()
 
+    @mock.patch("torch.distributed.get_rank", return_value=0)
+    @mock.patch("torchtitan.components.checkpoint.dcp.save")
+    def test_load_only_prevents_saving(self, mock_save, mock_rank):
+        """
+        Test that load_only=True prevents checkpoint saving.
+        """
+        mock_save.side_effect = self.fake_save
+
+        # Configure load_only=True
+        cfg = self.job_config.checkpoint
+        cfg.load_only = True
+        cfg.interval = 1  # Set a low interval to ensure saves would normally trigger
+
+        manager = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            checkpoint_config=self.job_config.checkpoint,
+            sd_adapter=None,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        # Test various save conditions that would normally trigger saves
+        manager.save(curr_step=1)  # Regular step save
+        self.assertEqual(mock_save.call_count, 0)
+
+        manager.save(curr_step=5)  # Interval-based save
+        self.assertEqual(mock_save.call_count, 0)
+
+        manager.save(curr_step=10, last_step=True)  # Last step save
+        self.assertEqual(mock_save.call_count, 0)
+
+        manager.close()
+
+        # Verify that saves work normally when load_only=False
+        mock_save.reset_mock()
+        cfg.load_only = False
+
+        manager2 = CheckpointManager(
+            dataloader=self.data_loader,
+            model_parts=self.model_parts,
+            optimizers=self.optimizers,
+            lr_schedulers=self.lr_schedulers,
+            states=self.states,
+            checkpoint_config=self.job_config.checkpoint,
+            sd_adapter=None,
+            base_folder=self.job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
+        manager2.save(curr_step=1)  # Should trigger save now
+        self.assertEqual(mock_save.call_count, 1)
+
+        manager2.close()
+
     @mock.patch("torch.distributed.get_rank", return_value=0)
     @mock.patch("torchtitan.components.checkpoint.dcp.load")
     @mock.patch("torchtitan.components.checkpoint.dcp.save")
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index 360e46212..1b25aa3f5 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -190,6 +190,7 @@ def __init__(
         ft_manager: FTManager | None = None,
     ) -> None:
         self.enable = checkpoint_config.enable
+        self.load_only = checkpoint_config.load_only
 
         self.ft_manager = (
             ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -761,7 +762,7 @@ def _save_last_step(self, curr_step: int) -> None:
         )
 
     def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
-        if not self.enable:
+        if not self.enable or self.load_only:
             return False
 
         if curr_step == 1 and self.enable_first_step_checkpoint:
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 304f84bda..eb477941c 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -527,6 +527,14 @@ class Checkpoint:
     Could be implemented as a separate script, but this way shares more code.
     """
 
+    load_only: bool = False
+    """
+    In certain scenarios, you may only need to load checkpoints for verification or debugging
+    purposes, without saving any new checkpoints. For example, you might use seed checkpoints
+    to validate model correctness. Enabling this option allows checkpoints to be loaded
+    without saving any during training.
+    """
+
 
 @dataclass
 class ActivationCheckpoint: