From 236aa9228cc24b70c8575a9243db25df990c6e5c Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Tue, 25 Jun 2024 16:23:03 -0700
Subject: [PATCH 1/2] Add the option to turn on async-TP

[ghstack-poisoned]
---
 torchtitan/config_manager.py                 | 6 ++++++
 torchtitan/parallelisms/parallelize_llama.py | 6 ++++++
 train.py                                     | 5 +++++
 3 files changed, 17 insertions(+)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2ff216e14..8a47a6ac0 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -241,6 +241,12 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+        self.parser.add_argument(
+            "--experimental.enable_async_tensor_parallel",
+            default=False,
+            action="store_true",
+            help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
+        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_degree",
             type=int,
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index cbdae71bf..be627432a 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -394,6 +394,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             parallelize_plan=layer_plan,
         )
 
+    if job_config.experimental.enable_async_tensor_parallel:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        torch._inductor.config._micro_pipeline_tp = True
+        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
     logger.info("Applied Tensor Parallelism to the model")
     return model
 
diff --git a/train.py b/train.py
index 8e55c2105..a2ae65dc0 100644
--- a/train.py
+++ b/train.py
@@ -252,6 +252,11 @@ def loss_fn(pred, labels):
         for m in model_parts
     ]
 
+    # for ease of testing TP in lieu of FSDP
+    if job_config.training.tensor_parallel_degree == world_size:
+        for model in model_parts:
+            model.to(torch.bfloat16)
+
     init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
     for model in model_parts:
         model.to_empty(device=init_device)

From 6fde13b4430560368b036a8c7cc93f34d184a975 Mon Sep 17 00:00:00 2001
From: Yifu Wang
Date: Wed, 26 Jun 2024 16:39:24 -0700
Subject: [PATCH 2/2] Update on "Add the option to turn on async-TP"

This PR adds the option to turn on async-TP
(`--experimental.enable_async_tensor_parallel`). The feature is currently
implemented as compiler passes on relevant patterns, so the option only takes
effect when compile is enabled.

Some trace samples from llama3_70b with tp degree=8:

**all-gather -> qkv projection**

Baseline:
[trace image]

With async-TP:
[trace image]

**ffn -> reduce-scatter**

Baseline:
[trace image]

With async-TP:
[trace image]

**all-gather -> ffn**

Baseline:
[trace image]

With async-TP:
[trace image]

[ghstack-poisoned]
---
 train.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/train.py b/train.py
index a2ae65dc0..8e55c2105 100644
--- a/train.py
+++ b/train.py
@@ -252,11 +252,6 @@ def loss_fn(pred, labels):
         for m in model_parts
     ]
 
-    # for ease of testing TP in lieu of FSDP
-    if job_config.training.tensor_parallel_degree == world_size:
-        for model in model_parts:
-            model.to(torch.bfloat16)
-
     init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
     for model in model_parts:
         model.to_empty(device=init_device)
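
For context, below is a minimal standalone sketch (not part of the patches above) of how the two switches this PR adds are meant to be used together with `torch.compile`. Only the `_micro_pipeline_tp` and `enable_symm_mem_for_group` lines come from the patch; the toy FFN, the sequence-parallel layouts, the shapes, and the launch details are illustrative assumptions, not torchtitan code.

```python
# Launch with (assumption): torchrun --nproc_per_node=8 async_tp_sketch.py
import os

import torch
import torch.nn as nn
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_degree = int(os.environ["WORLD_SIZE"])
tp_mesh = init_device_mesh("cuda", (tp_degree,), mesh_dim_names=("tp",))

# The two switches this patch flips when --experimental.enable_async_tensor_parallel
# is set: run inductor's micro-pipeline TP passes, and back the TP process group
# with symmetric memory so the decomposed collectives can overlap with matmuls.
torch._inductor.config._micro_pipeline_tp = True
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

# Toy stand-in for an FFN block (torchtitan applies a much larger plan to llama).
# Sequence-parallel layouts (Shard(1) on the sequence dim) give the compiled graph
# the all-gather -> matmul and matmul -> reduce-scatter patterns that the async-TP
# passes look for.
ffn = nn.Sequential(
    nn.Linear(4096, 16384, bias=False, device="cuda"),
    nn.Linear(16384, 4096, bias=False, device="cuda"),
)
parallelize_module(
    ffn,
    tp_mesh,
    {
        "0": ColwiseParallel(input_layouts=Shard(1)),
        "1": RowwiseParallel(output_layouts=Shard(1)),
    },
)

# Async-TP is implemented as compiler passes, so it only takes effect under compile.
ffn = torch.compile(ffn)
local_seq = torch.randn(2, 256, 4096, device="cuda")  # each rank's sequence shard
out = ffn(local_seq)
```

In torchtitan itself the new flag simply gates these two calls inside `apply_tp`, and compile must also be enabled in the job config for the inductor passes to run; whether the fused kernels actually fire for a given model can be verified from a profiler trace, as in the samples above.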