Skip to content

Commit 3115424

Browse files
committed
Re-enable FSDP+TP w/ strided sharding
ghstack-source-id: 13d0d4c Pull Request resolved: #507
1 parent fa7fe1e commit 3115424

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
3030
from torchtitan.logging import logger
3131
from torchtitan.parallelisms.parallel_dims import ParallelDims
32+
from torchtitan.parallelisms.utils import check_strided_sharding_enabled
3233

3334

3435
def parallelize_llama(
@@ -313,17 +314,8 @@ def apply_fsdp(
313314
)
314315
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
315316

316-
if pp_enabled:
317-
# TODO
318-
# This PR https://github.com/pytorch/pytorch/pull/129519 added a safety check to avoid using 2D/3D DCP since
319-
# without strided sharding, DCP can not safely support resharding for 2D/3D. However, for PP to work, even
320-
# without resharding, we load a seed-checkpoint and need to disable the safety mechanism. This hack should be
321-
# removed after strided sharding is landed in DCP.
322-
for module in model.modules():
323-
assert len(module._load_state_dict_pre_hooks) <= 1
324-
module._load_state_dict_pre_hooks.clear()
325-
assert len(module._state_dict_pre_hooks) <= 1
326-
module._state_dict_pre_hooks.clear()
317+
# check if strided sharding is enabled, which is necessary for 2D/3D DCP
318+
check_strided_sharding_enabled()
327319

328320
logger.info("Applied FSDP to the model")
329321

torchtitan/parallelisms/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import re

import torch

from torchtitan.logging import logger
9+
10+
11+
def check_strided_sharding_enabled() -> None:
    """Warn if the installed PyTorch may predate DTensor strided sharding.

    Correct 2D/3D DCP (distributed checkpoint) usage requires DTensor's
    strided sharding, added in
    https://github.com/pytorch/pytorch/pull/130760. Nightly builds from
    2024-08-09 onward include it; source builds cannot be dated
    automatically, so the user is asked to verify manually.

    Only emits warnings via ``logger``; never raises and returns nothing.
    """
    version = torch.__version__
    if "git" in version:  # pytorch is built from source
        # A source build carries a commit hash instead of a nightly date, so
        # we cannot check automatically -- notify the user to verify that the
        # build includes a commit newer than 2024-08-09.
        logger.warning(
            "detected that the pytorch is built from source. Please make sure the PR "
            "(https://github.com/pytorch/pytorch/pull/130760) is included in pytorch "
            "for correct 2D/3D DCP usage."
        )
        return

    # Compare versions numerically instead of lexicographically: a plain
    # string comparison misorders e.g. "2.10.0" < "2.5.0.dev20240809" and
    # would also flag the stable "2.5.0" release (which contains the PR) as
    # too old.
    match = re.match(r"(\d+)\.(\d+)\.(\d+)(?:\.dev(\d+))?", version)
    if match is None:
        # Unrecognized version format: fall back to the lexicographic check
        # as a best effort.
        is_old = version < "2.5.0.dev20240809"
    else:
        release = tuple(int(part) for part in match.group(1, 2, 3))
        dev = match.group(4)
        # Per PEP 440, a ".dev" (nightly) build sorts before the final
        # release of the same version, hence the (release, is_final,
        # dev_date) sort key.
        key = (release, 0 if dev else 1, int(dev) if dev else 0)
        is_old = key < ((2, 5, 0), 0, 20240809)

    if is_old:
        # the build predates 2024-08-09 and may lack the strided-sharding PR
        logger.warning(
            f"detected that the pytorch version {torch.__version__} is older than "
            "2.5.0.dev20240809. Please upgrade a newer version to include the change "
            "made in https://github.com/pytorch/pytorch/pull/130760 for correct 2D/3D "
            "DCP usage."
        )

0 commit comments

Comments
 (0)