
Commit 6a430d8

WIP integrate pippy's tracer frontend
- dcp load seems to work now
- need to pull in schedule object

ghstack-source-id: cbbb8c9
Pull Request resolved: #161

6 files changed: 69 additions (+), 17 deletions (-)

run_llama_train.sh

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
 
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}
 
 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}

torchtrain/meta_init.py

Lines changed: 6 additions & 0 deletions
@@ -46,3 +46,9 @@ def meta_to_real_init_fn(module: nn.Module):
                     torch.randn_like(param, device=torch.device("cuda"))
                 )
                 setattr(submodule, param_name, materialized_param)
+        for param_name, param in submodule.named_buffers(recurse=False):
+            if param.is_meta:
+                materialized_param = nn.Parameter(
+                    torch.randn_like(param, device=torch.device("cuda"))
+                )
+                setattr(submodule, param_name, materialized_param)
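
Note: the new loop mirrors the existing parameter path for meta-device buffers. As a standalone illustration of the idea (materializing meta buffers onto CUDA), here is a minimal sketch; RoPECache and materialize_meta_buffers are illustrative names, not part of this PR, and the sketch re-registers the tensor as a buffer rather than wrapping it in nn.Parameter as the diff does. Assumes a CUDA device is available.

import torch
import torch.nn as nn

class RoPECache(nn.Module):
    """Toy module with a non-trainable buffer, analogous to freqs_cis."""
    def __init__(self, dim: int = 16, seq_len: int = 32):
        super().__init__()
        self.register_buffer("freqs", torch.randn(seq_len, dim))

def materialize_meta_buffers(module: nn.Module) -> None:
    # Same idea as the loop added above: replace any meta-device buffer
    # with a real CUDA tensor of the same shape/dtype.
    for submodule in module.modules():
        for name, buf in submodule.named_buffers(recurse=False):
            if buf.is_meta:
                submodule.register_buffer(name, torch.randn_like(buf, device="cuda"))

with torch.device("meta"):
    m = RoPECache()
assert m.freqs.is_meta
materialize_meta_buffers(m)
print(m.freqs.device)  # cuda:0 (values are random; real init / checkpoint load happens later)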

torchtrain/models/llama/model.py

Lines changed: 11 additions & 8 deletions
@@ -334,13 +334,16 @@ def __init__(self, model_args: ModelArgs):
         self.model_args = model_args
         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
 
-        self.freqs_cis = precompute_freqs_cis(
-            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
-            # of models is 4096.
-            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
-            # or fine-tuning.
-            self.model_args.dim // self.model_args.n_heads,
-            self.model_args.max_seq_len * 2,
+        self.register_buffer(
+            "freqs_cis",
+            precompute_freqs_cis(
+                # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+                # of models is 4096.
+                # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
+                # or fine-tuning.
+                self.model_args.dim // self.model_args.n_heads,
+                self.model_args.max_seq_len * 2,
+            ),
         )
 
     def forward(self, tokens: torch.Tensor):
@@ -355,7 +358,7 @@ def forward(self, tokens: torch.Tensor):
         """
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
+        # self.freqs_cis = self.freqs_cis.to(h.device)
         freqs_cis = self.freqs_cis[0:seqlen]
         return h, freqs_cis
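
A minimal sketch (toy modules, not the Transformer itself, assuming a CUDA device) of why the explicit .to(h.device) can be commented out once freqs_cis is a registered buffer: buffers follow the module through .to()/.cuda() and appear in state_dict(), which also makes them visible to checkpointing and tracing frontends, whereas a plain tensor attribute does not.

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffer: moved by .to()/.cuda(), saved in state_dict,
        # but not returned by parameters() and not trained.
        self.register_buffer("freqs_cis", torch.randn(4096, 64))

class WithPlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: ignored by .to()/.cuda() and by state_dict.
        self.freqs_cis = torch.randn(4096, 64)

a, b = WithBuffer().cuda(), WithPlainAttr().cuda()
print(a.freqs_cis.device)             # cuda:0 -- no manual .to(device) needed
print(b.freqs_cis.device)             # cpu    -- would need the old .to(h.device)
print("freqs_cis" in a.state_dict())  # True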

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 35 additions & 6 deletions
@@ -8,6 +8,7 @@
 from typing import Tuple
 
 import torch
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._tensor import Replicate, Shard
 
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -143,7 +144,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     # apply PTD parallelisms
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     # First we apply Tensor Parallelism if it's enabled
     if parallel_dims.tp_enabled:
@@ -256,10 +281,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     meta_to_real_init_fn(model)
     model.cuda()
 
-    # TODO(whc) - proposal: remove this call, and assert that we always load a checkpoint
-    # we have now moved from meta to device,
-    # reset parameters for proper initialization
-    model.reset_parameters()
-    logger.info("Model fully initialized via reset_parameters")
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+    else:
+        # TODO(whc) - proposal: remove this call, and assert that we always load a checkpoint
+        # we have now moved from meta to device,
+        # reset parameters for proper initialization
+        model.reset_parameters()
+        logger.info("Model fully initialized via reset_parameters")
 
     return model
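
For orientation, a condensed sketch of the tracer-frontend flow used above (annotate_split_points -> Pipe.from_tracing -> get_stage_module) on a toy model. The Toy class, pp_degree, and the fixed stage index are illustrative; exact pippy APIs vary by version, so treat this as an outline mirroring the calls in the diff rather than a tested recipe.

import torch
import torch.nn as nn
from pippy import annotate_split_points, Pipe, PipeSplitWrapper

class Toy(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

pp_degree = 2
model = Toy()

# Mark stage boundaries by fully-qualified module name, as the diff does for layers.N.
layers_per_rank = len(model.layers) // pp_degree
for i in range(1, pp_degree):
    annotate_split_points(
        model,
        {f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING},
    )

# Trace once with example inputs; each rank then extracts only its own stage module.
example = torch.randn(8, 16)
pipe = Pipe.from_tracing(model, pp_degree, example_args=(example,))
stage_module = pipe.get_stage_module(0)  # in the real code: pp_mesh.get_local_rank()
print(stage_module)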

train.py

Lines changed: 15 additions & 1 deletion
@@ -187,6 +187,18 @@ def main(job_config: JobConfig):
         model, world_mesh, parallel_dims, job_config
     )
 
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        pmod = model
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+        logger.info(
+            f"{Color.blue}Extracting pipeline module for stage {pp_mesh.get_local_rank()}{Color.reset}"
+        )
+        model = pmod.get_stage_module(pp_mesh.get_local_rank())
+
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
@@ -258,10 +270,12 @@ def main(job_config: JobConfig):
 
             input_ids = input_ids.cuda()
             labels = labels.cuda()
-
+            print("i", input_ids.shape)
+            print("l", labels.shape)
             optimizer.zero_grad()
 
             # forward
+            # TODO - integrate pp batch splitter
            pred = model(input_ids)
 
            with loss_parallel() if parallel_dims.loss_parallel_enabled else contextlib.nullcontext():
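
On the "integrate pp batch splitter" TODO: a pipeline schedule feeds the stage one microbatch at a time rather than the whole global batch. The sketch below only illustrates that splitting step; split_into_microbatches is a hypothetical helper, not the pippy schedule object the commit message says still needs to be pulled in.

import torch

def split_into_microbatches(input_ids: torch.Tensor, labels: torch.Tensor, n_microbatches: int):
    # Illustrative only: chop the global batch along dim 0 so a PP schedule
    # can pipeline one microbatch per stage at a time.
    return list(
        zip(
            torch.chunk(input_ids, n_microbatches, dim=0),
            torch.chunk(labels, n_microbatches, dim=0),
        )
    )

input_ids = torch.randint(32000, (8, 2048))
labels = torch.randint(32000, (8, 2048))
for micro_inputs, micro_labels in split_into_microbatches(input_ids, labels, n_microbatches=4):
    print(micro_inputs.shape, micro_labels.shape)  # torch.Size([2, 2048]) each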

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 checkpoint_interval = 3600
