From 403b68b402622536dd0dfa00f123d5e0c7bc3dfd Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Sat, 3 Aug 2024 22:29:16 -0700
Subject: [PATCH] [do not review][example] train.py without pp

[ghstack-poisoned]
---
 train.py | 127 ++++++++++++------------------------------------------
 1 file changed, 27 insertions(+), 100 deletions(-)

diff --git a/train.py b/train.py
index 58d23307c..967f8f95b 100644
--- a/train.py
+++ b/train.py
@@ -21,12 +21,7 @@
 from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
-from torchtitan.parallelisms import (
-    build_pipeline_schedule,
-    models_parallelize_fns,
-    models_pipelining_fns,
-    ParallelDims,
-)
+from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 
 
@@ -51,6 +46,9 @@ def main(job_config: JobConfig):
     init_logger()
     logger.info(f"Starting job: {job_config.job.description}")
 
+    if job_config.experimental.pipeline_parallel_degree > 1:
+        raise RuntimeError("To use Pipeline Parallelism, please run train.py")
+
     # used for colorful printing
     color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor
 
@@ -82,9 +80,6 @@ def main(job_config: JobConfig):
     else:
         dp_degree, dp_rank = 1, 0
 
-    if parallel_dims.pp_enabled:
-        pp_mesh = world_mesh["pp"]
-
     model_name = job_config.model.name
 
     # build tokenizer
@@ -115,17 +110,17 @@ def main(job_config: JobConfig):
 
     logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
     with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)
 
     # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)
 
     # log model size
-    model_param_count = utils.get_num_params(whole_model)
+    model_param_count = utils.get_num_params(model)
     num_flop_per_token = utils.get_num_flop_per_token(
-        utils.get_num_params(whole_model, exclude_embedding=True),
+        utils.get_num_params(model, exclude_embedding=True),
         model_config,
         job_config.training.seq_len,
     )
@@ -134,41 +129,10 @@ def main(job_config: JobConfig):
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
 
-    if parallel_dims.pp_enabled:
-        stages, model_parts = models_pipelining_fns[model_name](
-            whole_model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-    else:
-        # In 1D/2D cases or PP with simple schedules, model_parts is just one item
-        # for PP with looped schedules, each item is one stage-model-chunk
-        # we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
-        model_parts = [whole_model]
-
     # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]
-
-    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)
-
-    # loss fn can be shared by pipeline-parallel or non-pp execution
-    def loss_fn(pred, labels):
-        return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
-        )
-
-    if parallel_dims.pp_enabled:
-        pp_schedule = build_pipeline_schedule(
-            job_config, parallel_dims, stages, loss_fn
-        )
-    else:
-        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
-        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
-        # allocate sharded model on GPU and initialize weights via DTensor
-        whole_model.init_weights()
+    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+    model.to_empty(device="cuda")
+    model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -178,43 +142,26 @@ def loss_fn(pred, labels):
     )
 
     # build optimizer after applying parallelisms to the model
-    optimizers = build_optimizers(model_parts, job_config)
+    optimizers = build_optimizers([model], job_config)
     lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
 
     train_state = TrainState()
 
     # train loop
-    for model in model_parts:
-        model.train()
+    model.train()
 
     # load initial checkpoint
     checkpoint = CheckpointManager(
         dataloader=data_loader,
-        model_parts=model_parts,
+        model_parts=[model],
         optimizers=optimizers.optimizers,
         lr_schedulers=lr_schedulers.schedulers,
         states={"train_state": train_state},
         job_config=job_config,
     )
-
-    if job_config.checkpoint.create_seed_checkpoint:
-        assert (
-            world_size == 1
-        ), "Must create seed-checkpoint using one gpu, to disable sharding"
-        checkpoint.save(curr_step=0, force=True)
-        logger.info("Created seed checkpoint")
-        return
-
     checkpoint_loaded = checkpoint.load()
 
-    if parallel_dims.pp_enabled and not checkpoint_loaded:
-        raise RuntimeError(
-            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
-            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
-        )
-
     metric_logger = build_metric_logger(job_config, parallel_dims)
-
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
     # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
@@ -271,43 +218,23 @@ def loss_fn(pred, labels):
             labels = labels.cuda()
             optimizers.zero_grad()
 
-            if parallel_dims.pp_enabled:
-                # pipeline parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
-                with train_context():
-                    if pp_mesh.get_local_rank() == 0:
-                        pp_schedule.step(input_ids)
-                    elif is_last_stage:
-                        losses = []
-                        pp_schedule.step(target=labels, losses=losses)
-                    else:
-                        pp_schedule.step()
-
-                # accumulate losses across pipeline microbatches
-                loss = (
-                    torch.mean(torch.stack(losses))
-                    if is_last_stage
-                    else torch.Tensor([-1.0])
+            with train_context():
+                pred = model(input_ids)
+                loss = torch.nn.functional.cross_entropy(
+                    pred.flatten(0, 1), labels.flatten(0, 1)
                 )
-            else:
-                # Non-PP forward / backward
-                with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
-                    loss.backward()
+                # pred.shape=(bs, seq_len, vocab_size)
+                # need to free to before bwd to avoid peaking memory
+                del pred
+                loss.backward()
 
             # clip gradients
-            for model in model_parts:
-                torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), job_config.training.max_norm, foreach=True
-                )
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )
 
             # sync float8 amaxes and scales
-            float8_handler.sync_float8_amax_and_scale_history(model_parts)
+            float8_handler.sync_float8_amax_and_scale_history(model)
 
             # optimizer step
             checkpoint.maybe_wait_for_staging()
@@ -316,7 +243,7 @@ def loss_fn(pred, labels):
 
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
-            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
+            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
 
             losses_since_last_log.append(loss)