From 1581c6360849c67a8ea67ae60c72b1d56db5d8b8 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Wed, 10 Jul 2024 07:59:27 -0700
Subject: [PATCH] Made some stylistic changes to `apply_dp`

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index c07d4c3334..c91695cadf 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -458,23 +458,21 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
         reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
     )
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-
     for layer_id, transformer_block in model.layers.items():
-        # As an optimization, do not reshard after forward for the last
-        # transformer block since FSDP would prefetch it immediately.
-        # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
-        # per microbatch.
-        reshard_after_forward = (
-            int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
-        )
+        if parallel_dims.pp_enabled:
+            # For PP, do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = False
+        else:
+            # As an optimization, do not reshard after forward for the last
+            # transformer block since FSDP would prefetch it immediately
+            reshard_after_forward = int(layer_id) < len(model.layers) - 1
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-        model.layers[layer_id] = transformer_block
-
-    model = fully_shard(
+    fully_shard(
         model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
     )
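
Not part of the patch: a minimal sketch of how the reshard_after_forward selection reads once this change lands. It assumes the FSDP2 fully_shard / MixedPrecisionPolicy APIs that torchtitan already imports; apply_dp_sketch and its parameters are illustrative stand-ins for the real apply_dp signature, not the file's actual code.

# Sketch only. Assumptions: torch with FSDP2 (torch.distributed._composable.fsdp)
# and a model whose .layers is a dict-like container keyed by stringified indices.
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def apply_dp_sketch(model, dp_mesh, pp_enabled, param_dtype, reduce_dtype):
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    for layer_id, transformer_block in model.layers.items():
        if pp_enabled:
            # Pipeline parallelism runs many microbatches per step; keeping
            # parameters unsharded after forward (ZeRO-2-like) avoids a fresh
            # all-gather per microbatch.
            reshard_after_forward = False
        else:
            # Reshard every block except the last: FSDP prefetches the last
            # block's parameters right away, so resharding it saves nothing.
            reshard_after_forward = int(layer_id) < len(model.layers) - 1
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    # fully_shard mutates the module in place, which is why the patch drops the
    # old `model = fully_shard(...)` reassignment and the per-layer writeback.
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

The per-block fully_shard calls plus the final root wrap keep all-gathers grouped at transformer-block granularity; only the reshard policy differs between the PP and non-PP paths.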