diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 2ff216e14..8a47a6ac0 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -241,6 +241,12 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+        self.parser.add_argument(
+            "--experimental.enable_async_tensor_parallel",
+            default=False,
+            action="store_true",
+            help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
+        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_degree",
             type=int,
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index cbdae71bf..be627432a 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -394,6 +394,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         parallelize_plan=layer_plan,
     )

+    if job_config.experimental.enable_async_tensor_parallel:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        torch._inductor.config._micro_pipeline_tp = True
+        enable_symm_mem_for_group(tp_mesh.get_group().group_name)
+
     logger.info("Applied Tensor Parallelism to the model")

     return model
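
For context, here is a minimal standalone sketch of how the two enablement calls in `apply_tp` compose. Only the Inductor flag and `enable_symm_mem_for_group` call mirror the diff; the process-group and device-mesh setup below is an illustrative assumption, not part of the change.

```python
# Illustrative sketch (not part of the diff): enabling async tensor parallel
# on a 1-D TP mesh. Run with: torchrun --nproc_per_node=<tp_degree> <script>.py
import torch
import torch._inductor.config
import torch.distributed as dist
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Hypothetical TP-only mesh spanning all ranks; torchtitan derives tp_mesh
# from its full world mesh instead.
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

# 1) Ask Inductor to micro-pipeline TP collectives so communication can be
#    overlapped with the surrounding matmuls at compile time.
torch._inductor.config._micro_pipeline_tp = True

# 2) Register the TP process group for symmetric-memory-backed kernels used
#    by the micro-pipelined collectives.
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

# 3) The rewrites only apply to compiled regions, which is why the new flag's
#    help text notes it is "only effective when compile is enabled".
# model = torch.compile(tensor_parallelized_model)
```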