Commit 19cf33e

Add Pipeline Parallel (and 2D PP+FSDP) support
Runs PP+DP and PP+TP without issue; runs PP+TP+DP with decreasing loss, but fails DCP save.

Supports only simple schedules currently: gpipe and 1f1b.

Adds a cmdline/toml arg for specifying split points, in a unified way between the tracer and manual frontends; e.g. the user can specify "layers.2,layers.4" as split points.

Currently uses the manual frontend by default, but allows specifying the tracer frontend. The tracer frontend requires working around additional compatibility limitations, indicated by raised assertions, and is not ready for wider use yet.

ghstack-source-id: e07b70e
Pull Request resolved: #318
1 parent 4567f56 commit 19cf33e
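The split-point semantics described in the commit message above (and spelled out in the config help further down) can be illustrated with a small standalone sketch. This is not torchtitan's actual splitting code, just a toy illustration of how N split points yield N + 1 stages, with each named module beginning a new stage:

def assign_stages(layer_names, split_points):
    """Map each layer name to a stage index; a split point begins a new stage."""
    stages, current = {}, 0
    for name in layer_names:
        if name in split_points:
            current += 1  # this layer starts the next stage
        stages[name] = current
    return stages

layers = [f"layers.{i}" for i in range(6)]
print(assign_stages(layers, {"layers.2", "layers.4"}))
# {'layers.0': 0, 'layers.1': 0, 'layers.2': 1, 'layers.3': 1, 'layers.4': 2, 'layers.5': 2}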

File tree

9 files changed, +466 −36 lines changed


.github/workflows/unit_test_4gpu.yaml

Lines changed: 1 addition & 1 deletion
@@ -30,5 +30,5 @@ jobs:
           python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install -r requirements.txt
           python -m pip install -r dev-requirements.txt
-          mkdir artifacts-to-be-uploaded
+          python -m pip install git+https://github.com/pytorch/pippy
           python ./test_runner.py artifacts-to-be-uploaded

create_seed_checkpoint.sh

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ LOG_RANK=0
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}
 
 seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint"
-force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1"
+force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1"
 overrides=""
 if [ $# -ne 0 ]; then
     overrides="$*"

test_runner.py

Lines changed: 89 additions & 8 deletions
@@ -30,6 +30,8 @@ class OverrideDefinitions:
 
     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
+    ngpu: int = 4
 
 
 CONFIG_DIR = "./train_configs"
@@ -102,25 +104,104 @@ class OverrideDefinitions:
         ],
         "Checkpoint Integration Test - Save Model Weights Only bf16",
     ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--job.dump_folder {args.output_dir}/pp/",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.data_parallel_degree 1",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP 1D test",
+        requires_seed_checkpoint=True,
+        ngpu=2,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--job.dump_folder {args.output_dir}/pp_dp/",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.data_parallel_degree 2",
+                "--model.norm_type fused_rmsnorm",
+            ],
+        ],
+        "PP+DP 2D test",
+        requires_seed_checkpoint=True,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--job.dump_folder {args.output_dir}/pp_tp/",
+                "--experimental.pipeline_parallel_degree 2",
+                "--experimental.pipeline_parallel_split_points layers.1",
+                "--training.tensor_parallel_degree 2",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP+TP 2D test",
+        requires_seed_checkpoint=True,
+    ),
+    # oh.. not enough GPUs?
+    # OverrideDefinitions(
+    #     [
+    #         [
+    #             "--checkpoint.enable_checkpoint",
+    #             f"--job.dump_folder {args.output_dir}/pp_dp_tp/",
+    #             "--experimental.pipeline_parallel_degree 2",
+    #             "--experimental.pipeline_parallel_split_points layers.1",
+    #             "--training.data_parallel_degree 2",
+    #             "--training.tensor_parallel_degree 2",
+    #             "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+    #         ],
+    #     ],
+    #     "PP+DP+TP 3D test",
+    #     requires_seed_checkpoint=True,
+    # ),
 ]
 
 
+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
-        cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+
+        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += " " + " ".join(override_arg)
         print(
            f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
        )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
+
+        if test_flavor.requires_seed_checkpoint:
+            dump_folder_arg = None
+            for arg in override_arg:
+                if "--job.dump_folder" in arg:
+                    dump_folder_arg = arg
+            assert (
+                dump_folder_arg is not None
+            ), "Can't use seed checkpoint if folder is not specified"
+            print("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
+            )
+            print(result.stdout)
+
+        result = _run_cmd(cmd)
         print(result.stdout)
         if result.returncode != 0:
             raise Exception(
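For concreteness, here is a rough sketch, outside of the runner itself, of the two commands run_test composes for the "PP 1D test" flavor above. The config path and dump folder values are hypothetical placeholders:

full_path = "./train_configs/debug_model.toml"     # hypothetical config path
dump_folder_arg = "--job.dump_folder ./outputs/pp/"  # hypothetical output dir

seed_cmd = f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
train_cmd = (
    f"CONFIG_FILE={full_path} NGPU=2 LOG_RANK=0,1,2,3 ./run_llama_train.sh "
    f"--checkpoint.enable_checkpoint {dump_folder_arg} "
    "--experimental.pipeline_parallel_degree 2 "
    "--experimental.pipeline_parallel_split_points layers.1 "
    "--training.data_parallel_degree 1 "
    "--model.norm_type rmsnorm"
)
print(seed_cmd)
print(train_cmd)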

torchtitan/config_manager.py

Lines changed: 56 additions & 2 deletions
@@ -17,6 +17,10 @@
 from torchtitan.logging_utils import logger
 
 
+def string_list(raw_arg):
+    return raw_arg.split(",")
+
+
 class JobConfig:
     """
     A helper class to manage the train configuration.
@@ -202,10 +206,56 @@ def __init__(self):
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
         self.parser.add_argument(
-            "--training.pipeline_parallel_degree",
+            "--experimental.pipeline_parallel_degree",
             type=int,
             default=1,
-            help="Pipeline Parallelism degree. 1 means disabled.",
+            help="""
+                Pipeline Parallelism degree, or number of ranks. 1 means disabled.
+                If using looped schedules, this still specifies the number of physical ranks, not the number
+                of stages. Stages per rank are inferred from split points degree, and schedule.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_points",
+            type=string_list,
+            nargs="+",
+            default=[],
+            help="""
+                Specify comma-separated names of modules to use as the beginning of a split point.
+
+                e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
+                the first containing all the layers up to layers.0,
+                the second containing layers.0 and up to layers.2,
+                the third containing layers.2 and all the remaining layers.
+
+                Note: fully-automated splitting may be enabled in the future,
+                but currently the split points must be specified manually for both manual and tracer.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_schedule",
+            type=str,
+            choices=["1f1b", "gpipe"],
+            default="1f1b",
+            help="""
+                Specify the Pipeline Parallel schedule to use.
+
+                The schedule must be compatible with the split points and stages_per_rank.
+
+                Looped schedules are not yet supported in torchtitan.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_mode",
+            type=str,
+            choices=["manual", "tracer"],
+            default="manual",
+            help="""
+                Specify the split method (e.g. the Pipeline Parallelism Front End)
+
+                "manual" means each rank will construct an nn.Module with the appropriate layers and .forward
+                implementation manually, and then wrap it in a PipelineStage.
+
+                "tracer" means the full model will be initialized (via meta device) and then traced into a graph,
+                split via the provided split points, unflattened into an nn.Module,
+                and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
         )
         self.parser.add_argument(
             "--training.compile",
@@ -408,6 +458,10 @@ def parse_args_from_command_line(
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
+            elif arg == "experimental.pipeline_parallel_split_points":
+                # type inference breaks here, since the type is just 'list' and it ends up flattening
+                # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
+                aux_parser.add_argument("--" + arg, type=string_list)
             else:
                 aux_parser.add_argument("--" + arg, type=type(val))
 
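A minimal sketch of the parsing behavior the new string_list helper and the type-inference comment above describe; it reuses only the two-line helper from this diff and is not the full JobConfig machinery:

def string_list(raw_arg):
    return raw_arg.split(",")

# The custom converter keeps comma-separated module names intact:
print(string_list("layers.0,layers.2"))  # ['layers.0', 'layers.2']

# Whereas the naive type inferred from a list-valued default (type=list)
# iterates over the characters of the string, as noted in the comment:
print(list("layers.0"))  # ['l', 'a', 'y', 'e', 'r', 's', '.', '0']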

torchtitan/parallelisms/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -9,12 +9,16 @@
 
 from torch.distributed.device_mesh import init_device_mesh
 from torchtitan.logging_utils import logger
-from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama
 
 models_parallelize_fns = {
     "llama2": parallelize_llama,
     "llama3": parallelize_llama,
 }
+models_pipelining_fns = {
+    "llama2": pipeline_llama,
+    "llama3": pipeline_llama,
+}
 
 
 @dataclass
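A hedged usage sketch of the new lookup table: a training script can pick the pipelining entry point by model name the same way it picks the parallelize function. The pipeline_llama call signature is not shown in this diff, so only the lookup is illustrated:

from torchtitan.parallelisms import models_parallelize_fns, models_pipelining_fns

model_name = "llama3"
parallelize_fn = models_parallelize_fns[model_name]  # parallelize_llama
pipeline_fn = models_pipelining_fns[model_name]      # pipeline_llama
# The exact arguments pipeline_llama expects are defined in
# torchtitan/parallelisms/parallelize_llama.py and are not shown in this diff.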
