Commit bc51495

Update on "Enable CP"
This PR adds experimental flags and functions to enable context parallelism. We currently support only FSDP + CP and CP alone; CP + TP is being tested. [ghstack-poisoned]

2 parents: 3538828 + 6d1ced5

3 files changed: 2 additions & 4 deletions

torchtitan/models/llama/model.py

Lines changed: 1 addition & 1 deletion

@@ -410,7 +410,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.dim // self.model_args.n_heads,
             # Need to compute until at least the max token limit for generation
             # (use 2x max sequence length to be safe)
-            self.model_args.max_seq_len * 2,
+            self.model_args.max_seq_len,
             self.model_args.rope_theta,
         )
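
The change drops the 2x headroom, so the rotary-frequency cache is now sized exactly to max_seq_len (note the "to be safe" comment above it is left unchanged). For orientation, the llama-style helper this method wraps is conventionally implemented as below; a minimal sketch matching the common reference implementation, not necessarily this file verbatim:

    import torch

    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
        # Frequencies for each even channel pair: theta^(-2i/dim).
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
        # Outer product with positions [0, end) gives per-position angles.
        t = torch.arange(end, dtype=torch.float32)
        angles = torch.outer(t, freqs)
        # Complex exponentials e^(i * angle), shape (end, dim // 2).
        return torch.polar(torch.ones_like(angles), angles)

The commit simply changes the `end` argument from `max_seq_len * 2` to `max_seq_len`.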

torchtitan/parallelisms/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -61,8 +61,6 @@ def build_mesh(self, device_type):
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
         world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
-        if self.cp > 1 and self.dp > 1:
-            world_mesh.create_view_dim(dims=("dp", "cp"), name="dp_cp")
         return world_mesh

     @property
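
This hunk removes the eager creation of a flattened "dp_cp" view at mesh-construction time; as the next file shows, the flattening now happens lazily at the FSDP call site instead. A minimal sketch of the equivalent mesh setup, assuming 8 ranks split as dp=4, cp=2 (the sizes are illustrative, and DeviceMesh._flatten is a private PyTorch API that may change):

    from torch.distributed.device_mesh import init_device_mesh

    # Run under torchrun with 8 processes; dims/names are assumptions.
    world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "cp"))

    # Slice the ("dp", "cp") submesh and collapse it into a single
    # 8-rank dimension, replacing the removed create_view_dim call.
    dp_cp_mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")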

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 1 addition & 1 deletion

@@ -506,7 +506,7 @@ def apply_fsdp(
     if parallel_dims.cp_enabled:
         # Temporary solution to enable FSDP + CP
         if parallel_dims.dp_enabled:
-            dp_mesh = world_mesh["dp_cp"]
+            dp_mesh = world_mesh["dp", "cp"]._flatten()
         else:
             dp_mesh = world_mesh["cp"]
     else:
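
Flattening "dp" and "cp" into one mesh dimension means FSDP shards parameters and gradients jointly across the data-parallel and context-parallel ranks. A hedged sketch of how such a mesh would feed FSDP2's fully_shard; the import path is the composable API of that era, and the model here is a stand-in rather than torchtitan's actual wiring:

    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard
    from torch.distributed.device_mesh import init_device_mesh

    # Run under torchrun with 8 processes; sizes are assumptions.
    world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "cp"))
    dp_mesh = world_mesh["dp", "cp"]._flatten()  # shard over all dp * cp ranks

    model = nn.Linear(1024, 1024)  # placeholder for the llama model
    fully_shard(model, mesh=dp_mesh)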
