
Commit 03d27ce

Update base for Update on "[BE] replace the extra DeviceMesh _flatten with mesh access"
**Summary** pytorch/pytorch#138945 fixes DeviceMesh access on flattened meshes that are constructed from more than 2 meshes. Refer to the fix PR for details if interested. In #592 we avoided this issue by calling `_flatten` instead of directly accessing the flattened mesh. Now that the fix has been merged in PyTorch, we want to switch back to mesh access, which is more straightforward. [ghstack-poisoned]
2 parents 53d0f69 + 2a785e9 commit 03d27ce
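
For context, a minimal sketch of the two access patterns the summary contrasts. The mesh shape and dimension names (`"dp_replicate"`, `"dp_shard"`, `"cp"`, `"dp_cp"`) are illustrative assumptions, not taken from this diff, and `DeviceMesh._flatten` is a private PyTorch API:

```python
# Sketch only: assumes an 8-rank distributed launch (e.g. via torchrun).
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 3-D mesh; the dim names are illustrative.
world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "cp")
)

# Workaround used in #592: call _flatten at the point of use instead of
# accessing the flattened mesh by name.
dp_cp_mesh = world_mesh["dp_replicate", "dp_shard", "cp"]._flatten(
    mesh_dim_name="dp_cp"
)

# After pytorch/pytorch#138945: a flattened mesh built from more than
# 2 dims can be accessed directly by its name.
dp_cp_mesh = world_mesh["dp_cp"]
```

Accessing the flattened mesh by name keeps call sites free of the flattening details, which is the simplification the summary describes.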

File tree

2 files changed: 0 additions, 37 deletions

- torchtitan/parallelisms/parallelize_llama.py
- torchtitan/parallelisms/utils.py


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 0 additions & 7 deletions
@@ -34,7 +34,6 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_strided_sharding_enabled


 def parallelize_llama(
@@ -330,12 +329,6 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()

-    # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
-    # that users won't use a nightly build which is older than 20240809 by then.
-    if tp_enabled:
-        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-        check_strided_sharding_enabled()
-
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch

torchtitan/parallelisms/utils.py

Lines changed: 0 additions & 30 deletions
This file was deleted.
