13 changes: 13 additions & 0 deletions docs/debugging.md
@@ -98,11 +98,24 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

For multiple experimental runs with different parallelism configs, we need to use a "seed" checkpoint to ensure that model initializations are the same across runs. This is because in `torchtitan/train.py`, the model parameters are sharded first and then have their weights initialized separately on each rank. As a result, this is not equivalent to initializing the model on one rank and then sharding it. Using a seed checkpoint lets different runs load the same model weights from the checkpoint -- DCP resharding will make sure the loaded weights are sharded correctly according to the parallelism configs.
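
To see why the order matters, here is a minimal single-process sketch (hypothetical shapes, not torchtitan's actual initialization code): drawing the full weight and slicing out a shard consumes the RNG in a different order than drawing only that shard directly, so the resulting values generally differ.

```python
# Minimal single-process sketch (hypothetical shapes, not torchtitan's actual
# init code) of why "shard, then initialize per rank" is not equivalent to
# "initialize, then shard".
import torch

# Initialize the full weight first, then take (what would be) rank 0's column shard.
torch.manual_seed(0)
full = torch.randn(4, 8)
shard_from_full = full[:, :4]

# Initialize only rank 0's shard directly, with the same seed.
torch.manual_seed(0)
shard_direct = torch.randn(4, 4)

# The RNG is consumed in a different order, so the values generally differ.
print(torch.allclose(shard_from_full, shard_direct))  # False
```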

#### Creating a Seed Checkpoint

```bash
NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```

#### Loading Seed Checkpoints for Debugging

When using seed checkpoints for debugging or validation purposes, you can enable the `load_only` configuration to load checkpoints without saving any new ones during training. This is particularly useful when you only want to verify model correctness or compare different configurations without cluttering your disk:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.load_only
```

The `--checkpoint.load_only` flag prevents the training process from saving any checkpoints, allowing you to:
- Run debugging sessions without generating unwanted checkpoint files
- Compare model behaviors using the same initial weights without checkpoint overhead

**Note**: Using a seed checkpoint only ensures that the model starts from the same initial weights when configs change; the training process may still not be identical even after setting the seed and enabling the `deterministic` mode, e.g. due to tensor shape changes, data precision changes, usage of randomness in model code, etc.
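
(For reference, a generic sketch of the PyTorch-level switches that seed and deterministic settings usually map to; this is illustrative only and not torchtitan's exact `deterministic` implementation.)

```python
# Generic PyTorch determinism setup (illustrative; not torchtitan's exact
# deterministic-mode implementation).
import torch

torch.manual_seed(42)                      # fix RNG state used for init, dropout, etc.
torch.use_deterministic_algorithms(True)   # prefer deterministic kernels, error otherwise
torch.backends.cudnn.benchmark = False     # avoid autotuned, run-dependent kernel choices
```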

### Example: Reproducing loss curves with different parallelism configs
58 changes: 58 additions & 0 deletions tests/unit_tests/test_checkpoint.py
@@ -675,6 +675,64 @@ def __init__(self):

manager.close()

@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.save")
def test_load_only_prevents_saving(self, mock_save, mock_rank):
"""
Test that load_only=True prevents checkpoint saving.
"""
mock_save.side_effect = self.fake_save

# Configure load_only=True
cfg = self.job_config.checkpoint
cfg.load_only = True
cfg.interval = 1 # Set low interval to ensure saves would normally trigger

manager = CheckpointManager(
dataloader=self.data_loader,
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states=self.states,
checkpoint_config=self.job_config.checkpoint,
sd_adapter=None,
base_folder=self.job_config.job.dump_folder,
ft_manager=self.ft_manager,
)

# Test various save conditions that would normally trigger saves
manager.save(curr_step=1) # Regular step save
self.assertEqual(mock_save.call_count, 0)

manager.save(curr_step=5) # Interval-based save
self.assertEqual(mock_save.call_count, 0)

manager.save(curr_step=10, last_step=True) # Last step save
self.assertEqual(mock_save.call_count, 0)

manager.close()

# Verify that saves work normally when load_only=False
mock_save.reset_mock()
cfg.load_only = False

manager2 = CheckpointManager(
dataloader=self.data_loader,
model_parts=self.model_parts,
optimizers=self.optimizers,
lr_schedulers=self.lr_schedulers,
states=self.states,
checkpoint_config=self.job_config.checkpoint,
sd_adapter=None,
base_folder=self.job_config.job.dump_folder,
ft_manager=self.ft_manager,
)

manager2.save(curr_step=1) # Should trigger save now
self.assertEqual(mock_save.call_count, 1)

manager2.close()

@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torchtitan.components.checkpoint.dcp.load")
@mock.patch("torchtitan.components.checkpoint.dcp.save")
3 changes: 2 additions & 1 deletion torchtitan/components/checkpoint.py
@@ -190,6 +190,7 @@ def __init__(
ft_manager: FTManager | None = None,
) -> None:
self.enable = checkpoint_config.enable
self.load_only = checkpoint_config.load_only

self.ft_manager = (
ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -761,7 +762,7 @@ def _save_last_step(self, curr_step: int) -> None:
)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
if not self.enable:
if not self.enable or self.load_only:
return False

if curr_step == 1 and self.enable_first_step_checkpoint:
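
For context, here is a simplified sketch of how this guard sits in front of the other save triggers (assumed fields and interval handling, not the exact `CheckpointManager` logic). With `load_only` set, every save request in the test above falls through to `False`:

```python
# Simplified sketch of the save gating (assumed fields; not the exact
# torchtitan CheckpointManager implementation).
def should_save(
    enable: bool,
    load_only: bool,
    curr_step: int,
    interval: int,
    last_step: bool = False,
) -> bool:
    if not enable or load_only:
        # load_only short-circuits everything, including last-step saves.
        return False
    if last_step:
        return True
    return curr_step % interval == 0

# Mirrors the expectations in test_load_only_prevents_saving:
assert should_save(True, True, curr_step=10, interval=1, last_step=True) is False
assert should_save(True, False, curr_step=1, interval=1) is True
```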
8 changes: 8 additions & 0 deletions torchtitan/config/job_config.py
@@ -527,6 +527,14 @@ class Checkpoint:
Could be implemented as a separate script, but this way shares more code.
"""

load_only: bool = False
"""
In certain scenarios, you may only need to load checkpoints for verification or debugging
purposes, without saving any new checkpoints. For example, you might use seed checkpoints
to validate model correctness. Enabling this option allows checkpoints to be loaded
without saving any new ones during training.
"""


@dataclass
class ActivationCheckpoint: