Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def apply_compile(model, job_config: JobConfig):

def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply data parallelism to the model. FSDP2 is used here.
Apply data parallelism (FSDP2) to the model.
"""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
Expand All @@ -461,21 +461,20 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

for layer_id, transformer_block in model.layers.items():
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately.
# When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings
# per microbatch.
reshard_after_forward = (
int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled
)
if parallel_dims.pp_enabled:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

separate out the comments to make it clearer

# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not need to assign back in


model = fully_shard(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

fully_shard(
model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled
)

Expand Down