
[do not review][example] train.py without pp #501

Closed
wants to merge 2 commits
train.py: 127 changes (27 additions, 100 deletions)
@@ -21,12 +21,7 @@
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import (
build_pipeline_schedule,
models_parallelize_fns,
models_pipelining_fns,
ParallelDims,
)
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


@@ -51,6 +46,9 @@ def main(job_config: JobConfig):
init_logger()
logger.info(f"Starting job: {job_config.job.description}")

if job_config.experimental.pipeline_parallel_degree > 1:
raise RuntimeError("To use Pipeline Parallelism, please run train.py")

# used for colorful printing
color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor

@@ -82,9 +80,6 @@
else:
dp_degree, dp_rank = 1, 0

if parallel_dims.pp_enabled:
pp_mesh = world_mesh["pp"]

model_name = job_config.model.name

# build tokenizer
@@ -115,17 +110,17 @@

logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)
model = model_cls.from_model_args(model_config)

# a no-op handler if float8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(whole_model)
float8_handler.convert_to_float8_training(model)

# log model size
model_param_count = utils.get_num_params(whole_model)
model_param_count = utils.get_num_params(model)
num_flop_per_token = utils.get_num_flop_per_token(
utils.get_num_params(whole_model, exclude_embedding=True),
utils.get_num_params(model, exclude_embedding=True),
model_config,
job_config.training.seq_len,
)
@@ -134,41 +129,10 @@
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
)

if parallel_dims.pp_enabled:
stages, model_parts = models_pipelining_fns[model_name](
whole_model, pp_mesh, parallel_dims, job_config, device, model_config
)
else:
# In 1D/2D cases or PP with simple schedules, model_parts is just one item
# for PP with looped schedules, each item is one stage-model-chunk
# we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
model_parts = [whole_model]

# apply PT-D DP/TP parallelisms and activation checkpointing
model_parts = [
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
for m in model_parts
]

init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
for model in model_parts:
model.to_empty(device=init_device)

# loss fn can be shared by pipeline-parallel or non-pp execution
def loss_fn(pred, labels):
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)

if parallel_dims.pp_enabled:
pp_schedule = build_pipeline_schedule(
job_config, parallel_dims, stages, loss_fn
)
else:
# If PP is enabled, we can't rely on init_weights, because some layers are missing.
# In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
# allocate sharded model on GPU and initialize weights via DTensor
whole_model.init_weights()
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
model.to_empty(device="cuda")
model.init_weights()

gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
logger.info(
@@ -178,43 +142,26 @@ def loss_fn(pred, labels):
)

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers(model_parts, job_config)
optimizers = build_optimizers([model], job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

train_state = TrainState()

# train loop
for model in model_parts:
model.train()
model.train()

# load initial checkpoint
checkpoint = CheckpointManager(
dataloader=data_loader,
model_parts=model_parts,
model_parts=[model],
optimizers=optimizers.optimizers,
lr_schedulers=lr_schedulers.schedulers,
states={"train_state": train_state},
job_config=job_config,
)

if job_config.checkpoint.create_seed_checkpoint:
assert (
world_size == 1
), "Must create seed-checkpoint using one gpu, to disable sharding"
checkpoint.save(curr_step=0, force=True)
logger.info("Created seed checkpoint")
return

checkpoint_loaded = checkpoint.load()

if parallel_dims.pp_enabled and not checkpoint_loaded:
raise RuntimeError(
"Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
"Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
)

metric_logger = build_metric_logger(job_config, parallel_dims)

# plot losses loaded from checkpoint (if any) to TensorBoard
# NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
# This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
@@ -271,43 +218,23 @@ def loss_fn(pred, labels):
labels = labels.cuda()
optimizers.zero_grad()

if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
else:
pp_schedule.step()

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
with train_context():
pred = model(input_ids)
loss = torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)
else:
# Non-PP forward / backward
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free it before bwd to avoid peaking memory
del pred
loss.backward()
# pred.shape=(bs, seq_len, vocab_size)
# need to free it before bwd to avoid peaking memory
del pred
loss.backward()

# clip gradients
for model in model_parts:
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)

# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model_parts)
float8_handler.sync_float8_amax_and_scale_history(model)

# optimizer step
checkpoint.maybe_wait_for_staging()
@@ -316,7 +243,7 @@ def loss_fn(pred, labels):

# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)

losses_since_last_log.append(loss)

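For context, a minimal sketch of the meta-device initialization pattern the updated train.py relies on: build the model under `torch.device("meta")` so no real memory is allocated, materialize empty storage on the target device with `to_empty`, then initialize weights in place. `TinyModel` and its `init_weights` below are made-up stand-ins, not torchtitan code.

```python
import torch


class TinyModel(torch.nn.Module):
    """Hypothetical placeholder for the real model class."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def init_weights(self):
        # stands in for the model's real init_weights()
        torch.nn.init.trunc_normal_(self.proj.weight, std=0.02)
        torch.nn.init.zeros_(self.proj.bias)


with torch.device("meta"):
    model = TinyModel()  # parameters exist but own no storage yet

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to_empty(device=device)  # allocate uninitialized storage on the device
model.init_weights()           # write real values into the allocated tensors
print(next(model.parameters()).device)
```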
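And a hedged, self-contained sketch of the simplified non-pipeline-parallel training step this PR converges to: forward pass, cross-entropy over flattened logits, freeing the logits before backward, gradient clipping, and an optimizer step. The toy model, optimizer choice, and constants are placeholders rather than torchtitan's.

```python
import torch

vocab_size, seq_len, bs, max_norm = 32, 16, 4, 1.0

# placeholder model: embedding followed by a linear head
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, 64),
    torch.nn.Linear(64, vocab_size),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

input_ids = torch.randint(0, vocab_size, (bs, seq_len))
labels = torch.randint(0, vocab_size, (bs, seq_len))

model.train()
optimizer.zero_grad()

pred = model(input_ids)  # (bs, seq_len, vocab_size)
# cross_entropy expects (N, num_classes) logits and (N,) targets,
# hence flatten(0, 1) on both tensors
loss = torch.nn.functional.cross_entropy(
    pred.flatten(0, 1), labels.flatten(0, 1)
)
del pred  # free the logits before backward to reduce peak memory, as in the diff
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, foreach=True)
optimizer.step()
print(f"loss: {loss.item():.4f}")
```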