Commit 41fb425

Add Pipeline Parallel (and 2D PP+FSDP) support
Runs PP+DP and PP+TP without issue; runs PP+TP+DP with decreasing loss, but fails DCP save. Supports only simple schedules currently: gpipe and 1f1b.

Adds a cmdline/toml arg for specifying split points, in a way unified between the tracer and manual frontends; e.g. the user can specify "layers.2,layers.4" as split points. Currently uses the manual frontend by default, but allows specifying the tracer frontend. The tracer frontend requires working around additional compatibility limitations, surfaced as raised assertions, and is not ready for wider use yet.

ghstack-source-id: 9dd9b7a
Pull Request resolved: #318
1 parent ef0a26d commit 41fb425
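
For orientation, here is a small Python sketch of the split-point semantics described above (written for this summary, not code from the commit): each named module begins a new stage, so N split points yield N+1 stages.

# Illustrative sketch only, not part of this commit.
def partition(layer_names, split_points):
    stages, current = [], []
    for name in layer_names:
        # A split point marks the module that starts the next stage.
        if name in split_points and current:
            stages.append(current)
            current = []
        current.append(name)
    stages.append(current)
    return stages

# "layers.2,layers.4" splits a 6-layer model into 3 stages:
print(partition([f"layers.{i}" for i in range(6)], {"layers.2", "layers.4"}))
# [['layers.0', 'layers.1'], ['layers.2', 'layers.3'], ['layers.4', 'layers.5']]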

File tree

8 files changed (+451, -34 lines)

.github/workflows/unit_test_4gpu.yaml

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ jobs:
           pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install -r requirements.txt
           python -m pip install -r dev-requirements.txt
+          python -m pip install git+https://github.com/pytorch/pippy
       - name: Run test_runner.py
         run: python ./test_runner.py
       - name: Upload Coverage to Codecov

test_runner.py

Lines changed: 89 additions & 8 deletions
@@ -26,6 +26,8 @@ class OverrideDefinitions:

     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
+    ngpu: int = 4


 CONFIG_DIR = "./train_configs"
@@ -85,25 +87,104 @@ class OverrideDefinitions:
         ],
         "Checkpoint Integration Test - Save Model Weights Only bf16",
     ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--checkpoint.folder {test_checkpoint_dir}_pp",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.data_parallel_degree 1",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP 1D test",
+        requires_seed_checkpoint=True,
+        ngpu=2,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--checkpoint.folder {test_checkpoint_dir}_pp_dp",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.data_parallel_degree 2",
+                "--model.norm_type fused_rmsnorm",
+            ],
+        ],
+        "PP+DP 2D test",
+        requires_seed_checkpoint=True,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--checkpoint.folder {test_checkpoint_dir}_pp_tp",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.tensor_parallel_degree 2",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP+TP 2D test",
+        requires_seed_checkpoint=True,
+    ),
+    # oh.. not enough GPUs?
+    # OverrideDefinitions(
+    #     [
+    #         [
+    #             "--checkpoint.enable_checkpoint",
+    #             f"--checkpoint.folder {test_checkpoint_dir}_pp_dp_tp",
+    #             "--experimental.pipeline_parallel_degree 2",
+    #             "--experimental.pipeline_parallel_split_points layers.1",
+    #             "--training.data_parallel_degree 2",
+    #             "--training.tensor_parallel_degree 2",
+    #             "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+    #         ],
+    #     ],
+    #     "PP+DP+TP 3D test",
+    #     requires_seed_checkpoint=True,
+    # ),
 ]


+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
-        cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+
+        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += " " + " ".join(override_arg)
         print(
             f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
+
+        if test_flavor.requires_seed_checkpoint:
+            checkpoint_folder_arg = None
+            for arg in override_arg:
+                if "--checkpoint.folder" in arg:
+                    checkpoint_folder_arg = arg
+            assert (
+                checkpoint_folder_arg is not None
+            ), "Can't use seed checkpoint if folder is not specified"
+            print("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"
+            )
+            print(result.stdout)
+
+        result = _run_cmd(cmd)
         print(result.stdout)
         if result.returncode != 0:
             raise Exception(
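
For context, a hypothetical call into the updated runner might look like this; OverrideDefinitions, run_test, and the new fields come from the diff above, while the checkpoint folder and .toml path are illustrative stand-ins:

# Hypothetical usage; the folder and config path are illustrative.
flavor = OverrideDefinitions(
    [
        [
            "--checkpoint.enable_checkpoint",
            "--checkpoint.folder ./outputs/test_pp",  # required by the seed-checkpoint path
            "--experimental.pipeline_parallel_degree 2",
            "--experimental.pipeline_parallel_split_points layers.1",
        ],
    ],
    "PP smoke test",
    requires_seed_checkpoint=True,  # runs create_seed_checkpoint.sh before training
    ngpu=2,  # propagated as NGPU=2 into run_llama_train.sh
)
run_test(flavor, "./train_configs/debug_model.toml")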

torchtitan/config_manager.py

Lines changed: 56 additions & 2 deletions
@@ -17,6 +17,10 @@
 from torchtitan.logging_utils import logger


+def string_list(raw_arg):
+    return raw_arg.split(",")
+
+
 class JobConfig:
     """
     A helper class to manage the train configuration.
@@ -202,10 +206,56 @@ def __init__(self):
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
         self.parser.add_argument(
-            "--training.pipeline_parallel_degree",
+            "--experimental.pipeline_parallel_degree",
             type=int,
             default=1,
-            help="Pipeline Parallelism degree. 1 means disabled.",
+            help="""
+                Pipeline Parallelism degree, or number of ranks. 1 means disabled.
+                If using looped schedules, this still specifies the number of physical ranks, not the number
+                of stages. Stages per rank are inferred from split points degree, and schedule.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_points",
+            type=string_list,
+            nargs="+",
+            default=[],
+            help="""
+                Specify comma-separated names of modules to use as the beginning of a split point.
+
+                e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
+                the first containing all the layers up to layers.0,
+                the second containing layers.0 and up to layers.2,
+                the third containing layers.2 and all the remaining layers.
+
+                Note: fully-automated splitting may be enabled in the future,
+                but currently the split points must be specified manually for both manual and tracer.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_schedule",
+            type=str,
+            choices=["1f1b", "gpipe"],
+            default="1f1b",
+            help="""
+                Specify the Pipeline Parallel schedule to use.
+
+                The schedule must be compatible with the split points and stages_per_rank.
+
+                Looped schedules are not yet supported in torchtitan.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_mode",
+            type=str,
+            choices=["manual", "tracer"],
+            default="manual",
+            help="""
+                Specify the split method (e.g. the Pipeline Parallelism Front End)
+
+                "manual" means each rank will construct an nn.Module with the appropriate layers and .forward
+                implementation manually, and then wrap it in a PipelineStage.
+
+                "tracer" means the full model will be initialized (via meta device) and then traced into a graph,
+                split via the provided split points, unflattened into an nn.Module,
+                and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
         )
         self.parser.add_argument(
             "--training.compile",
@@ -408,6 +458,10 @@ def parse_args_from_command_line(
             aux_parser.add_argument(
                 "--" + arg, action="store_true" if val else "store_false"
             )
+        elif arg == "experimental.pipeline_parallel_split_points":
+            # type inference breaks here, since the type is just 'list' and it ends up flattening
+            # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
+            aux_parser.add_argument("--" + arg, type=string_list)
         else:
            aux_parser.add_argument("--" + arg, type=type(val))
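
The in-code comment about type inference is the motivation for string_list: argparse applies the type callable to the raw argument string, and list(...) iterates a string character by character. A standalone demonstration of the pitfall (written for this summary, not part of the diff):

import argparse

def string_list(raw_arg):
    return raw_arg.split(",")

bad = argparse.ArgumentParser()
bad.add_argument("--split_points", type=list)  # list() flattens the string
print(bad.parse_args(["--split_points", "layers.0,layers.1"]).split_points)
# -> ['l', 'a', 'y', 'e', 'r', 's', '.', '0', ',', 'l', ...]

good = argparse.ArgumentParser()
good.add_argument("--split_points", type=string_list)  # splits on commas
print(good.parse_args(["--split_points", "layers.0,layers.1"]).split_points)
# -> ['layers.0', 'layers.1']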

torchtitan/parallelisms/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -9,12 +9,16 @@

 from torch.distributed.device_mesh import init_device_mesh
 from torchtitan.logging_utils import logger
-from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama

 models_parallelize_fns = {
     "llama2": parallelize_llama,
     "llama3": parallelize_llama,
 }
+models_pipelining_fns = {
+    "llama2": pipeline_llama,
+    "llama3": pipeline_llama,
+}


 @dataclass
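
models_pipelining_fns mirrors the existing models_parallelize_fns table, keyed by model name. A minimal lookup sketch follows; the call site and signature of pipeline_llama are not shown in this commit excerpt and are assumptions:

# Sketch only: how train.py invokes the resolved function is outside this hunk.
from torchtitan.parallelisms import models_pipelining_fns

pipeline_fn = models_pipelining_fns["llama3"]  # resolves to pipeline_llama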
