diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 32fbcc633d..f801072fb1 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -448,7 +448,7 @@ def apply_compile(model, job_config: JobConfig):
 
 def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply data parallelism to the model. FSDP2 is used here.
+    Apply data parallelism (FSDP2) to the model.
     """
     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
 
@@ -461,21 +461,20 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
     for layer_id, transformer_block in model.layers.items():
-        # As an optimization, do not reshard after forward for the last
-        # transformer block since FSDP would prefetch it immediately.
-        # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
-        # per microbatch.
-        reshard_after_forward = (
-            int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
-        )
+        if parallel_dims.pp_enabled:
+            # For PP, do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = False
+        else:
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = int(layer_id) < len(model.layers) - 1
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-        model.layers[layer_id] = transformer_block
-
-    model = fully_shard(
+    fully_shard(
         model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
     )
 
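
A minimal sketch (not part of the patch) of the per-block reshard_after_forward policy the hunk above introduces. The helper name reshard_after_forward_for_layer and the parameters layer_idx, num_layers, and pp_enabled are hypothetical stand-ins for int(layer_id), len(model.layers), and parallel_dims.pp_enabled.

# Hypothetical helper restating the control flow added in the diff above;
# it is not part of torchtitan, only an illustration of the decision per block.
def reshard_after_forward_for_layer(layer_idx: int, num_layers: int, pp_enabled: bool) -> bool:
    if pp_enabled:
        # With pipeline parallelism, keep parameters unsharded after forward
        # (zero-2 style) so each microbatch does not re-all-gather the block.
        return False
    # Otherwise reshard every block except the last, which FSDP would
    # prefetch again immediately, making the reshard wasted work.
    return layer_idx < num_layers - 1


if __name__ == "__main__":
    assert reshard_after_forward_for_layer(0, 4, pp_enabled=False)
    assert not reshard_after_forward_for_layer(3, 4, pp_enabled=False)
    assert not reshard_after_forward_for_layer(0, 4, pp_enabled=True)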