diff --git a/docs/debugging.md b/docs/debugging.md
index f7758cbde5..fd436367e9 100644
--- a/docs/debugging.md
+++ b/docs/debugging.md
@@ -70,7 +70,7 @@ When debugging issues with multi-dimensional parallelism (combinations of FSDP,
 Set consistent random seeds across all parallelism dimensions:
 
 ```bash
-CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.seed 42
 ```
 
 **Seed behavior with parallelism:**
@@ -84,7 +84,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
 Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs:
 
 ```bash
-CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic
 ```
 
 **What it does:**
@@ -93,6 +93,29 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
 - Sets deterministic workspace configuration for CuBLAS operations
 - **Note:** This will significantly reduce training performance but ensures exact reproducibility
 
+Use `--debug.deterministic_warn_only` to only warn, rather than error out, when an op has no deterministic implementation.
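+
+For example, to keep a run going and only log a warning when an op lacks a deterministic implementation:
+
+```bash
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic --debug.deterministic_warn_only
+```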
+
+### Activation Checkpointing Debugging
+
+The following `--debug` options control activation checkpointing (AC) behavior:
+
+- `ac_preserve_rng_state` - set to true if deterministic output compared to non-checkpointed passes is required. Stashes and restores the RNG state during each checkpoint, which may be slower.
+- `ac_determinism_check` - a string specifying the determinism check to run on recomputed tensors.
+- `ac_debug` - capture AC debug information. Will be slower.
+
+See https://docs.pytorch.org/docs/stable/checkpoint.html for details.
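+
+For example, to check whether AC recomputation is the source of a numerics mismatch, you might run:
+
+```bash
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.ac_preserve_rng_state --debug.ac_debug
+```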
 
 ### Seed-Checkpoint-based Reproducibility
 
diff --git a/torchtitan/config/__init__.py b/torchtitan/config/__init__.py
index ba2795a601..e70d7fb622 100644
--- a/torchtitan/config/__init__.py
+++ b/torchtitan/config/__init__.py
@@ -28,6 +28,7 @@
     Quantize,
     Training,
     Validation,
+    Debug,
 )
 
 from .manager import ConfigManager
@@ -49,4 +50,5 @@
     "Profiling",
     "Training",
     "Validation",
+    "Debug",
 ]
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index d7e2752aea..62b82cfc5c 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -247,15 +247,6 @@ class Training:
     many temporary files.
     """
 
-    seed: int | None = None
-    """Choose the base RNG seed used for training"""
-
-    deterministic: bool = False
-    """Use deterministic algorithms wherever possible, may be slower"""
-
-    debug_moe_force_load_balance: bool = False
-    """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""
-
 
 @dataclass
 class Parallelism:
@@ -884,6 +875,35 @@ def __post_init__(self):
         ), "validation steps must be positive or -1"
 
 
+@dataclass
+class Debug:
+    seed: int | None = None
+    """Choose the base RNG seed used for training"""
+
+    deterministic: bool = False
+    """Use deterministic algorithms wherever possible, may be slower"""
+
+    deterministic_warn_only: bool = False
+    """Only warn about ops without deterministic implementations instead of erroring out"""
+
+    ac_preserve_rng_state: bool = False
+    """If deterministic output compared to non-checkpointed passes is required, set to True.
+    Stashes and restores the RNG state during each checkpoint, which may be slower.
+    See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""
+
+    ac_determinism_check: str = "default"
+    """A string specifying the determinism check to run on recomputed tensors.
+    See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""
+
+    ac_debug: bool = False
+    """Capture activation checkpointing debug information. Will be slower.
+    See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""
+
+    moe_force_load_balance: bool = False
+    """If True, force each expert to get the same amount of tokens via round-robin.
+    This option is for debugging usage only."""
+
+
 @dataclass
 class JobConfig:
     """
@@ -909,6 +929,7 @@ class JobConfig:
     fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
     experimental: Experimental = field(default_factory=Experimental)
     validation: Validation = field(default_factory=Validation)
+    debug: Debug = field(default_factory=Debug)
 
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
index 57809c45f9..131c7f7f58 100644
--- a/torchtitan/distributed/activation_checkpoint.py
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -17,13 +17,14 @@
 )
 
 from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
+from torchtitan.config.job_config import Debug as DebugConfig
 from torchtitan.tools.logging import logger, warn_once
 
 
 _layer_sac_count = 0
 
 
-def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
+def _apply_layer_sac(module: nn.Module, ac_config: ACConfig, debug_config: DebugConfig) -> nn.Module:
     """Apply layer selective activation checkpointing to the module.
 
     Args:
@@ -38,7 +39,11 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
     ac_freq = int(ac_config.selective_ac_option)
     if not ac_freq or _layer_sac_count % ac_freq == 0:
         return ptd_checkpoint_wrapper(
-            module, preserve_rng_state=False, early_stop=ac_config.early_stop
+            module,
+            preserve_rng_state=debug_config.ac_preserve_rng_state,
+            determinism_check=debug_config.ac_determinism_check,
+            early_stop=ac_config.early_stop,
+            debug=debug_config.ac_debug,
         )
     else:
         return module
@@ -123,11 +128,13 @@ def selective_checkpointing_context_fn():
             return create_selective_checkpoint_contexts(_get_custom_policy(meta))
 
     return ptd_checkpoint_wrapper(
-        module,
-        context_fn=selective_checkpointing_context_fn,
-        preserve_rng_state=False,
-        early_stop=ac_config.early_stop,
-    )
+        module,
+        context_fn=selective_checkpointing_context_fn,
+        preserve_rng_state=dbg_config.ac_preserve_rng_state,
+        determinism_check=dbg_config.ac_determinism_check,
+        early_stop=ac_config.early_stop,
+        debug=dbg_config.ac_debug,
+    )
 
 
 def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
@@ -143,6 +150,10 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
-    return ptd_checkpoint_wrapper(
-        module, preserve_rng_state=False, early_stop=ac_config.early_stop
-    )
+    return ptd_checkpoint_wrapper(
+        module,
+        preserve_rng_state=dbg_config.ac_preserve_rng_state,
+        determinism_check=dbg_config.ac_determinism_check,
+        early_stop=ac_config.early_stop,
+        debug=dbg_config.ac_debug,
+    )
 
 
 def _apply_op_sac_to_transformer_block_with_flex(
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index c2ec7bd777..296a2ebc6c 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -18,6 +18,7 @@
 from torch.distributed.tensor import DTensor
 
 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
+from torchtitan.config import Debug as DebugConfig
 from torchtitan.distributed.parallel_dims import ParallelDims
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
@@ -83,8 +84,7 @@ def dist_mean(
 def set_determinism(
     world_mesh: DeviceMesh | None,
     device: torch.device,
-    seed: int | None = None,
-    deterministic: bool = False,
+    debug_config: DebugConfig,
     distinct_seed_mesh_dim: str = "pp",
 ) -> None:
     """
@@ -97,9 +97,11 @@ def set_determinism(
 
     Set Determinism flags for increased reproducibility with loss of performance.
     """
-    if deterministic:
+    if debug_config.deterministic:
         logger.info("Deterministic algorithm enabled (expect perf degradation).")
-        torch.use_deterministic_algorithms(True)
+        torch.use_deterministic_algorithms(
+            True, warn_only=debug_config.deterministic_warn_only
+        )
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
         # env var for deterministic CuBLAS
@@ -114,6 +116,7 @@ def set_determinism(
 
         FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)
 
+    seed = debug_config.seed
     if not world_mesh:
         if seed is not None:
             torch.manual_seed(seed)
diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py
index 624792e83e..33018f95fc 100644
--- a/torchtitan/experiments/flux/train.py
+++ b/torchtitan/experiments/flux/train.py
@@ -35,8 +35,7 @@ def __init__(self, job_config: JobConfig):
         dist_utils.set_determinism(
             self.parallel_dims.world_mesh,
             self.device,
-            job_config.training.seed,
-            job_config.training.deterministic,
+            job_config.debug,
             distinct_seed_mesh_dim="dp_shard",
         )
 
diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py
index f8b1412959..1d3a420b9d 100644
--- a/torchtitan/experiments/forge/engine.py
+++ b/torchtitan/experiments/forge/engine.py
@@ -104,8 +104,7 @@ def __init__(self, job_config: ForgeJobConfig):
         dist_utils.set_determinism(
             world_mesh,
             self.device,
-            job_config.training.seed,
-            job_config.training.deterministic,
+            job_config.debug,
         )
 
         self.train_spec = get_train_spec(job_config.model.name)
diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py
index b1c014cc1d..d255bc0b72 100644
--- a/torchtitan/experiments/forge/job_config.py
+++ b/torchtitan/experiments/forge/job_config.py
@@ -20,6 +20,7 @@
     Parallelism,
     Quantize,
     Training,
+    Debug,
 )
 
 
@@ -45,6 +46,7 @@ class ForgeJobConfig:
     # fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
     # experimental: Experimental = field(default_factory=Experimental)
     # validation: Validation = field(default_factory=Validation)
+    debug: Debug = field(default_factory=Debug)
 
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py
index faeb60aadf..53043a1d02 100644
--- a/torchtitan/experiments/llama4/model/args.py
+++ b/torchtitan/experiments/llama4/model/args.py
@@ -82,7 +82,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
 
         self.moe_args._debug_force_load_balance = (
-            job_config.training.debug_moe_force_load_balance
+            job_config.debug.moe_force_load_balance
         )
 
     def get_nparams_and_flops(
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index bf963a5b5f..2309c844a4 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -90,6 +90,7 @@ def __init__(self) -> None:
             SDPBackend.CUDNN_ATTENTION,
             SDPBackend.FLASH_ATTENTION,
             SDPBackend.EFFICIENT_ATTENTION,
+            SDPBackend.MATH,
         ]
 
     def forward(
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 6441ff0b2d..d94fa5ebc9 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -11,6 +11,7 @@
 from typing import Any, Generator, Iterable, Optional
 
 import torch
+
 from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.protocols.train_spec as train_spec_module
@@ -125,8 +126,7 @@ def __init__(self, job_config: JobConfig):
         dist_utils.set_determinism(
             world_mesh,
             self.device,
-            job_config.training.seed,
-            job_config.training.deterministic,
+            job_config.debug,
         )
 
         self.train_spec = train_spec_module.get_train_spec(job_config.model.name)