Commit e7c31be

Add Pipeline Parallel (and 2D PP+FSDP) support
Runs PP+DP and PP+TP without issue; runs PP+TP+DP with decreasing loss, but fails DCP save.

Supports only simple schedules currently: gpipe and 1f1b.

Adds a cmdline/toml arg for specifying split points, in a unified way between the tracer and manual frontends; e.g. the user can specify "layers.2,layers.4" as split points.

Currently uses the manual frontend by default, but allows specifying the tracer frontend. The tracer frontend requires working around additional compatibility limitations, indicated by raised assertions, and is not ready for wider use yet.

ghstack-source-id: d7e0a13
Pull Request resolved: #318
1 parent 99a73dd commit e7c31be

8 files changed (+440, -36 lines)
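For example, combining the new flags on the debug config, a 2-rank pipeline-parallel run driven through the repo's run_llama_train.sh could look like the command below. This is a sketch assembled from the integration tests added in this commit; the config path and values are illustrative, and the PP tests first create a seed checkpoint, as shown in the diffs that follow:

CONFIG_FILE=./train_configs/debug_model.toml NGPU=2 ./run_llama_train.sh --experimental.pipeline_parallel_degree 2 --experimental.pipeline_parallel_split_points layers.1 --experimental.pipeline_parallel_schedule 1f1b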

create_seed_checkpoint.sh

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ LOG_RANK=0
 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

 seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint"
-force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1"
+force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1"
 overrides=""
 if [ $# -ne 0 ]; then
     overrides="$*"
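With this change, creating a seed checkpoint for a pipeline-parallel run goes through the renamed flag. A minimal invocation, mirroring how the updated test runner below calls the script (the dump folder path is illustrative):

CONFIG_FILE=./train_configs/debug_model.toml ./create_seed_checkpoint.sh --job.dump_folder ./outputs/pp_1f1b/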

test_runner.py

Lines changed: 106 additions & 10 deletions

@@ -11,6 +11,8 @@
 from dataclasses import dataclass
 from typing import Sequence

+from torchtitan.logging_utils import logger
+
 try:
     import tomllib
 except ModuleNotFoundError:
@@ -25,6 +27,8 @@ class OverrideDefinitions:

     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
+    ngpu: int = 4


 def build_test_list(args):
@@ -35,6 +39,78 @@ def build_test_list(args):
     """
     integration_tests_flavors = defaultdict(list)
     integration_tests_flavors["debug_model.toml"] = [
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_1f1b/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule 1f1b",
+                    "--training.data_parallel_degree 1",
+                ],
+            ],
+            "PP 1D test 1f1b",
+            requires_seed_checkpoint=True,
+            ngpu=2,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_gpipe/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule gpipe",
+                    "--training.data_parallel_degree 1",
+                ],
+            ],
+            "PP 1D test gpipe",
+            requires_seed_checkpoint=True,
+            ngpu=2,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule 1f1b",
+                    "--training.data_parallel_degree 2",
+                ],
+            ],
+            "PP+DP 1f1b 2D test",
+            requires_seed_checkpoint=True,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--experimental.pipeline_parallel_schedule gpipe",
+                    "--training.data_parallel_degree 2",
+                ],
+            ],
+            "PP+DP gpipe 2D test",
+            requires_seed_checkpoint=True,
+        ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    f"--job.dump_folder {args.output_dir}/pp_tp/",
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_split_points layers.1",
+                    "--training.tensor_parallel_degree 2",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm not yet compatible with TP
+                ],
+            ],
+            "PP+TP 2D test",
+            requires_seed_checkpoint=True,
+        ),
         OverrideDefinitions(
             [
                 [
@@ -100,23 +176,43 @@ def build_test_list(args):
     return integration_tests_flavors


+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
-        cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+
+        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += " " + " ".join(override_arg)
-        print(
+        logger.info(
            f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
-        print(result.stdout)
+
+        if test_flavor.requires_seed_checkpoint:
+            dump_folder_arg = None
+            for arg in override_arg:
+                if "--job.dump_folder" in arg:
+                    dump_folder_arg = arg
+            assert (
+                dump_folder_arg is not None
+            ), "Can't use seed checkpoint if folder is not specified"
+            logger.info("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
+            )
+            logger.info(result.stdout)
+
+        result = _run_cmd(cmd)
+        logger.info(result.stdout)
         if result.returncode != 0:
             raise Exception(
                 f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"

torchtitan/config_manager.py

Lines changed: 69 additions & 2 deletions

@@ -25,6 +25,10 @@
 }


+def string_list(raw_arg):
+    return raw_arg.split(",")
+
+
 class JobConfig:
     """
     A helper class to manage the train configuration.
@@ -210,10 +214,68 @@ def __init__(self):
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
         self.parser.add_argument(
-            "--training.pipeline_parallel_degree",
+            "--experimental.pipeline_parallel_degree",
             type=int,
             default=1,
-            help="Pipeline Parallelism degree. 1 means disabled.",
+            help="""
+                Pipeline Parallelism degree, or number of ranks. 1 means disabled.
+                If using looped schedules, this still specifies the number of physical ranks, not the number
+                of stages. Stages per rank are inferred from split points degree, and schedule.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_points",
+            type=string_list,
+            nargs="+",
+            default=[],
+            help="""
+                Specify comma-separated names of modules to use as the beginning of a split point.
+
+                e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
+                the first containing all the layers up to layers.0,
+                the second containing layers.0 and up to layers.2,
+                the third containing layers.2 and all the remaining layers.
+
+                Note: fully-automated splitting may be enabled in the future,
+                but currently the split points must be specified manually for both manual and tracer.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_schedule",
+            type=str,
+            choices=["1f1b", "gpipe"],
+            default="1f1b",
+            help="""
+                Specify the Pipeline Parallel schedule to use.
+
+                The schedule must be compatible with the split points and stages_per_rank.
+
+                Looped schedules are not yet supported in torchtitan.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_split_mode",
+            type=str,
+            choices=["manual", "tracer"],
+            default="manual",
+            help="""
+                Specify the split method (e.g. the Pipeline Parallelism Front End)
+
+                "manual" means each rank will construct an nn.Module with the appropriate layers and .forward
+                implementation manually, and then wrap it in a PipelineStage.
+
+                "tracer" means the full model will be initialized (via meta device) and then traced into a graph,
+                split via the provided split points, unflattened into an nn.Module,
+                and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
+        )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_microbatches",
+            type=int,
+            default=None,
+            help="""
+                How many microbatches to split the global training batch into when using pipeline parallelism.
+
+                The global training batch size must be evenly divisible by the number of microbatches.
+
+                The default value will be the number of pipeline stages, if unspecified.
+            """,
         )
         self.parser.add_argument(
             "--training.mixed_precision_param",
@@ -437,6 +499,11 @@ def parse_args_from_command_line(
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
+            elif arg == "experimental.pipeline_parallel_split_points":
+                # without this special case, type inference breaks here,
+                # since the inferred type is just 'list' and it ends up flattening
+                # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
+                aux_parser.add_argument("--" + arg, type=string_list)
             else:
                 aux_parser.add_argument("--" + arg, type=type(val))
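A quick sketch of what the string_list helper does for the split-points flag, using a standalone argparse parser for illustration (the nargs handling of the real parser is omitted to show just the comma parsing; the real argument lives on JobConfig's parser):

import argparse

def string_list(raw_arg):
    return raw_arg.split(",")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.pipeline_parallel_split_points", type=string_list, default=[]
)
args = parser.parse_args(
    ["--experimental.pipeline_parallel_split_points", "layers.0,layers.2"]
)
# the flag name contains dots, so getattr is used instead of plain attribute access
split_points = getattr(args, "experimental.pipeline_parallel_split_points")
print(split_points)  # ['layers.0', 'layers.2'] -> model split into 3 stages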

torchtitan/parallelisms/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -9,12 +9,16 @@

 from torch.distributed.device_mesh import init_device_mesh
 from torchtitan.logging_utils import logger
-from torchtitan.parallelisms.parallelize_llama import parallelize_llama
+from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama

 models_parallelize_fns = {
     "llama2": parallelize_llama,
     "llama3": parallelize_llama,
 }
+models_pipelining_fns = {
+    "llama2": pipeline_llama,
+    "llama3": pipeline_llama,
+}


 @dataclass
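The new dict follows the same name-keyed dispatch pattern as models_parallelize_fns; a minimal sketch of how a caller might look up the pipelining function by model name (the call signature of pipeline_llama is not shown in this diff, so only the lookup is illustrated):

from torchtitan.parallelisms import models_parallelize_fns, models_pipelining_fns

model_name = "llama3"  # illustrative; matches the keys registered above
parallelize_fn = models_parallelize_fns[model_name]  # -> parallelize_llama
pipeline_fn = models_pipelining_fns[model_name]      # -> pipeline_llama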
