diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 1de3c82c9..c8526270b 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -187,7 +187,20 @@ def __init__(self):
             "--training.data_parallel_degree",
             type=int,
             default=-1,
-            help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
+            help="Data Parallelism degree (FSDP). -1 means leftover ranks will be used (after SP/PP/replicate). 1 means disabled.",
+        )
+        self.parser.add_argument(
+            "--training.data_parallel_replicate_degree",
+            type=int,
+            default=1,
+            help="""
+                Degree of data parallelism with replicated parameters. 1 means disabled.
+                If data_parallel_degree > 1 and data_parallel_replicate_degree > 1,
+                the parallelism is HSDP. HSDP is not yet enabled but will be supported soon.
+                When data_parallel_degree is -1 and data_parallel_replicate_degree > 1,
+                the parallelism is DDP. DDP should only be used for small models, as
+                DDP + TP is not yet supported.
+            """,
         )
         self.parser.add_argument(
             "--training.tensor_parallel_degree",
@@ -210,7 +223,16 @@ def __init__(self):
         self.parser.add_argument(
             "--training.compile",
             action="store_true",
-            help="Whether to compile the model",
+            help="Whether to compile the model.",
+        )
+        self.parser.add_argument(
+            "--training.compiled_autograd",
+            action="store_true",
+            help="""
+                Whether to use CompiledAutograd to trace the backward pass.
+                This is an experimental feature and should not be used
+                unless you are familiar with CompiledAutograd.
+            """,
         )
         self.parser.add_argument(
             "--training.fp8_linear",
diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py
index e791b832a..e42b543c5 100644
--- a/torchtitan/parallelisms/__init__.py
+++ b/torchtitan/parallelisms/__init__.py
@@ -20,6 +20,7 @@
 @dataclass
 class ParallelDims:
     dp: int
+    dp_replicate: int
     tp: int
     pp: int
     world_size: int
@@ -29,21 +30,27 @@ def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
+        dp, dp_replicate, tp, pp = self.dp, self.dp_replicate, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
+            self.dp = dp = self.world_size // (dp_replicate * tp * pp)
         assert dp >= 1, dp
+        assert dp_replicate >= 1, dp_replicate
         assert tp >= 1, tp
         assert pp >= 1, pp
         assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+            dp * dp_replicate * tp * pp == self.world_size
+        ), (
+            f"Invalid parallel dims: dp({dp}) * dp_replicate({dp_replicate}) * "
+            f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})."
+        )

     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp_replicate, self.dp, self.tp],
+            ["pp", "dp_replicate", "dp", "tp"],
+            strict=True
         ):
             if d > 1:
                 dims.append(d)
@@ -56,6 +63,10 @@ def build_mesh(self, device_type):
     def dp_enabled(self):
         return self.dp > 1

+    @property
+    def dp_replicate_enabled(self):
+        return self.dp_replicate > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 779be60e8..6671ef4f5 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -11,8 +11,10 @@
 from typing import Tuple

 import torch
+import torch.nn as nn

 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -129,7 +131,56 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel


-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def maybe_enable_activation_checkpoint(
+    model: nn.Module, job_config: JobConfig
+) -> nn.Module:
+    config = job_config.activation_checkpoint
+    ac_mode = config.mode
+    if ac_mode in ("full", "selective"):
+        for layer_id, transformer_block in enumerate(model.layers):
+            model.layers[layer_id] = checkpoint_wrapper(transformer_block, config)
+        logger.info(f"Applied {ac_mode} activation checkpointing to the model")
+
+    return model
+
+
+def enable_fsdp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
+    # TODO: Expose `reduce_dtype` as a config option.
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+    )
+    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    for layer_id, transformer_block in enumerate(model.layers):
+        # As an optimization, do not reshard after forward for the last
+        # transformer block since FSDP would prefetch it immediately
+        reshard_after_forward = layer_id < len(model.layers) - 1
+        fully_shard(
+            transformer_block,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+        model.layers[layer_id] = transformer_block
+    model = fully_shard(model, **fsdp_config)
+    logger.info("Applied FSDP to the model")
+
+    return model
+
+
+def enable_ddp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
+    if job_config.training.compile:
+        if job_config.training.compiled_autograd:
+            torch._dynamo.config.optimize_ddp = "python_reducer"
+        else:
+            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    logger.info("Applied DDP to the model")
+
+    return model
+
+
+def parallelize_llama(
+    model: nn.Module, world_mesh, parallel_dims, job_config: JobConfig
+) -> nn.Module:
     """
     Apply parallelisms and activation checkpointing to the model.

@@ -144,6 +195,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             raise NotImplementedError(
                 "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
             )
+        if parallel_dims.dp_replicate_enabled:
+            raise NotImplementedError("DDP/HSDP + TP are not supported yet.")

         tp_mesh = world_mesh["tp"]
         row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
@@ -206,32 +259,15 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

         logger.info("Applied Tensor Parallelism to the model")

+    model = maybe_enable_activation_checkpoint(model, job_config)
     if parallel_dims.dp_enabled:
+        if parallel_dims.dp_replicate_enabled:
+            raise NotImplementedError("HSDP is not supported yet.")
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

-        # TODO: Expose `reduce_dtype` as a config option.
-        mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
-        )
-        ac_mode = job_config.activation_checkpoint.mode
-        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-        for layer_id, transformer_block in enumerate(model.layers):
-            if job_config.activation_checkpoint.mode in ("full", "selective"):
-                transformer_block = checkpoint_wrapper(
-                    transformer_block, job_config.activation_checkpoint
-                )
-            # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = layer_id < len(model.layers) - 1
-            fully_shard(
-                transformer_block,
-                **fsdp_config,
-                reshard_after_forward=reshard_after_forward,
-            )
-            model.layers[layer_id] = transformer_block
-        model = fully_shard(model, **fsdp_config)
-        if ac_mode in ("full", "selective"):
-            logger.info(f"Applied {ac_mode} activation checkpointing to the model")
-        logger.info("Applied FSDP to the model")
+        model = enable_fsdp(model, dp_mesh, job_config)
+    elif parallel_dims.dp_replicate_enabled:
+        dp_mesh = world_mesh["dp_replicate"] if world_mesh.ndim > 1 else world_mesh
+        model = enable_ddp(model, dp_mesh, job_config)

     return model
diff --git a/train.py b/train.py
index 318c7174e..19f3b9250 100644
--- a/train.py
+++ b/train.py
@@ -121,6 +121,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        dp_replicate=job_config.training.data_parallel_replicate_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.training.pipeline_parallel_degree,
         world_size=world_size,
@@ -303,10 +304,13 @@ def loss_fn(pred, labels):
             optimizer.zero_grad()

             # forward / backward
-            with loss_parallel_ctx():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
-                loss.backward()
+            with torch._dynamo.utils.maybe_enable_compiled_autograd(
+                job_config.training.compiled_autograd
+            ):
+                with loss_parallel_ctx():
+                    pred = model(input_ids)
+                    loss = loss_fn(pred, labels)
+                    loss.backward()

             # clip gradients
             torch.nn.utils.clip_grad_norm_(
diff --git a/train_configs/llama_1b.toml b/train_configs/llama_1b.toml
new file mode 100644
index 000000000..5a5518468
--- /dev/null
+++ b/train_configs/llama_1b.toml
@@ -0,0 +1,40 @@
+# TorchTrain Config.toml
+[job]
+dump_folder = "./outputs"
+description = "LLaMA 1B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama2"
+flavor = "1B"
+norm_type = "fused_rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
+tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
+
+[optimizer]
+name = "AdamW"
+lr = 1.5e-4
+
+[training]
+batch_size = 8
+seq_len = 1024
+warmup_steps = 200  # lr scheduler warm up
+max_norm = 1.0  # grad norm clipping
+steps = 1000
+data_parallel_degree = -1
+tensor_parallel_degree = 1
+pipeline_parallel_degree = 1
+fp8_linear = ""
+compile = false
+dataset = "c4"
+
+[activation_checkpoint]
+mode = "none"  # ['none', 'full', 'selective']
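
Note: the short sketch below illustrates how the degree resolution added to `ParallelDims._validate` composes with the new `data_parallel_replicate_degree` knob. It is a minimal, self-contained example, not part of the patch; the helper name `resolve_dp` and the 8-GPU figures are hypothetical and only mirror the arithmetic in the hunk above.

def resolve_dp(world_size: int, dp: int, dp_replicate: int, tp: int, pp: int) -> int:
    # Mirrors ParallelDims._validate: dp == -1 means "use the leftover ranks"
    # after the replicate, tensor-parallel, and pipeline-parallel dimensions.
    if dp == -1:
        dp = world_size // (dp_replicate * tp * pp)
    assert dp * dp_replicate * tp * pp == world_size, "degrees must multiply to world_size"
    return dp

# Pure DDP on a hypothetical 8-GPU host: the leftover FSDP dimension collapses to 1,
# so build_mesh creates a single mesh dimension named "dp_replicate".
assert resolve_dp(world_size=8, dp=-1, dp_replicate=8, tp=1, pp=1) == 1

# The HSDP layout this PR reserves for later (currently raises NotImplementedError):
# 2-way replication over 4-way sharding.
assert resolve_dp(world_size=8, dp=-1, dp_replicate=2, tp=1, pp=1) == 4

In practice these layouts would be requested through the new knob, either as `--training.data_parallel_replicate_degree` on the command line or via the equivalent `data_parallel_replicate_degree` entry under `[training]` in a TOML config such as the new llama_1b.toml.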