From e73c23e5041f8e2052e6e2c3ae26f9090a0ce3e8 Mon Sep 17 00:00:00 2001
From: kvshbg-aws
Date: Tue, 19 Aug 2025 22:59:23 +0000
Subject: [PATCH 1/3] feat: enable metadata communication across hosts for XLA
 devices used by pipeline parallelism

---
 test/pipelining/test_basic_pipelining.py      | 197 ++++++++++++++++++
 test/run_tests.sh                             |   9 +
 torch_xla/distributed/pipelining/__init__.py  |   8 +
 .../xla_pipeline_stage_coordinator.py         |  98 +++++++++
 4 files changed, 312 insertions(+)
 create mode 100644 test/pipelining/test_basic_pipelining.py
 create mode 100644 torch_xla/distributed/pipelining/__init__.py
 create mode 100644 torch_xla/distributed/pipelining/xla_pipeline_stage_coordinator.py

diff --git a/test/pipelining/test_basic_pipelining.py b/test/pipelining/test_basic_pipelining.py
new file mode 100644
index 00000000000..36d5d9760b9
--- /dev/null
+++ b/test/pipelining/test_basic_pipelining.py
@@ -0,0 +1,197 @@
+# Test for XLA Pipeline Parallelism - requires torchrun --nproc_per_node=2
+# Only runs on NEURON devices with exactly 2 processes
+
+import os
+import sys
+import unittest
+import numpy as np
+import torch
+from torch import nn
+import torch.optim as optim
+import torch.distributed as dist
+
+import torch_xla.distributed.pipelining
+from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe
+import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
+import torch_xla
+import torch_xla.utils.utils as xu
+
+
+class MLPModule(torch.nn.Module):
+
+  def __init__(self, d_hid, num_classes, is_last=False):
+    super().__init__()
+    self.net1 = torch.nn.Linear(d_hid, d_hid // 2)
+    self.relu = torch.nn.ReLU()
+    self.net2 = torch.nn.Linear(d_hid // 2, num_classes if is_last else d_hid)
+
+  def forward(self, x):
+    x = self.net1(x)
+    x = self.relu(x)
+    x = self.net2(x)
+    return x
+
+
+class SimpleTransf(torch.nn.Module):
+
+  def __init__(self, hidden_dim, num_classes, num_layers):
+    super().__init__()
+    self.layers = torch.nn.Sequential(*[
+        MLPModule(hidden_dim, num_classes, is_last=(i == num_layers - 1))
+        for i in range(num_layers)
+    ])
+
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    return self.layers(x)
+
+
+class TestLazyBasicPipelining(unittest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    # Only run on NEURON devices
+    if xr.device_type() != 'NEURON':
+      raise unittest.SkipTest('Test only runs on NEURON devices')
+
+    # Require distributed environment with exactly 2 devices
+    if 'RANK' not in os.environ or 'WORLD_SIZE' not in os.environ:
+      raise unittest.SkipTest('Test requires torchrun with RANK and WORLD_SIZE')
+
+    world_size = int(os.environ.get('WORLD_SIZE', 0))
+    if world_size != 2:
+      raise unittest.SkipTest('Test requires exactly 2 devices')
+
+  def test_pipeline_training(self):
+    """Test distributed pipeline training with GPipe scheduling and loss convergence"""
+    # configs
+    hidden_dim = 1024
+    num_classes = 32
+    num_layers = 2
+    batch_size = 2
+    num_epochs = 2
+    lr = 0.01
+    train_dataset_len = 1024 * 8
+    gradient_accumulation_steps = 1
+    logging_steps = 1
+
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    chunks = world_size
+
+    device = torch_xla.device()
+    print(f"Rank {rank} using device {device}")
+
+    # Initialize process group
+    dist.init_process_group(
+        backend="xla", rank=rank, world_size=world_size, init_method='xla://')
+
+    torch.manual_seed(42)
+    model = SimpleTransf(hidden_dim, num_classes, num_layers).to("meta")
+
+    # Define split points for pipeline parallelism
+    split_spec = {}
+    for i in range(1, chunks):
+      layer_idx = (i * num_layers) // chunks
+      split_spec[f"layers.{layer_idx}"] = SplitPoint.BEGINNING
+
+    # Create a sample input for the pipeline
+    example_input = torch.randn(batch_size, hidden_dim, device="meta")
+
+    # Create the pipeline and respective stage for the rank
+    pipe = pipeline(model, mb_args=(example_input,), split_spec=split_spec)
+
+    # Assertions for pipeline creation
+    self.assertEqual(
+        pipe.num_stages, chunks,
+        f"Pipeline stages should match chunks: {pipe.num_stages} != {chunks}")
+
+    loss_fn = nn.CrossEntropyLoss()
+    stage_mod = pipe.get_stage_module(rank)
+    stage_mod = stage_mod.to_empty(device=device)
+    stage = PipelineStage(stage_mod, rank, chunks, device)
+    schedule = ScheduleGPipe(stage, batch_size, loss_fn=loss_fn)
+
+    del model
+    if rank == 0:
+      print(f"{rank=}, {schedule._get_pipeline_order()}\n", flush=True)
+
+    losses = []
+    to_check_losses = []
+    optimizer = optim.SGD(stage_mod.parameters(), lr=lr)
+
+    train_loader = xu.SampleGenerator(
+        data=(
+            torch.randn(batch_size, hidden_dim),
+            torch.randint(0, num_classes, (batch_size,), dtype=torch.int64),
+        ),
+        sample_count=train_dataset_len //
+        (batch_size * gradient_accumulation_steps),
+    )
+
+    def print_epoch(epoch, step, losses):
+      to_check_losses.append(losses[-1].item())
+      print(f"Epoch {epoch} step {step} loss {losses[-1]}")
+
+    print(f"{world_size=}, {rank=}")
+
+    # Training loop with rank-specific pipeline logic
+    for epoch in range(num_epochs):
+      for step, (data, target) in enumerate(train_loader):
+        optimizer.zero_grad()
+        if rank == 0:  # First rank handles input
+          data = data.to(device)
+          _ = schedule.step(data)
+        elif rank == world_size - 1:  # Last rank handles target and loss
+          target = target.to(device)
+          _ = schedule.step(target=target, losses=losses)
+        else:  # Middle ranks just forward/backward
+          _ = schedule.step()
+
+        if rank == world_size - 1:
+          if step % logging_steps == 0:
+            xm.add_step_closure(print_epoch, (epoch, step, losses))
+
+        optimizer.step()
+        torch.distributed.barrier()
+        torch_xla.sync()
+        if step == 100:  # break for assertions
+          break
+      break  # a single truncated epoch is enough for the checks below
+    # Validate training results
+    if to_check_losses:
+      if rank == 0:
+        self.assertTrue(
+            len(to_check_losses) == 0,
+            f"rank0 should not store losses, but it has: {to_check_losses}")
+      if rank == world_size - 1:
+        # Last rank records losses - verify they exist and are finite
+        self.assertGreater(
+            len(to_check_losses), 0, "Last rank should record losses")
+        for loss in to_check_losses:
+          self.assertTrue(
+              torch.isfinite(torch.Tensor([loss])),
+              f"Loss should be finite: {loss}")
+
+        # Check loss convergence (early vs late training)
+        if len(to_check_losses) >= 10:
+          early_losses = [
+              l for l in to_check_losses[:len(to_check_losses) // 3]
+          ]
+          late_losses = [
+              l for l in to_check_losses[-len(to_check_losses) // 3:]
+          ]
+          early_avg = sum(early_losses) / len(early_losses)
+          late_avg = sum(late_losses) / len(late_losses)
+
+          print(
+              f"Early avg loss: {early_avg:.4f}, Late avg loss: {late_avg:.4f}")
+          self.assertLessEqual(
+              late_avg, early_avg * 1.1,
+              f"Loss should generally decrease: early={early_avg:.4f}, late={late_avg:.4f}"
+          )
+    dist.destroy_process_group()
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 54c893c7b40..1691c3f5136 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -135,6 +135,14 @@ function run_pt_xla_debug_level2 {
   PT_XLA_DEBUG_LEVEL=2 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
 }
 
+function run_neuron_pipelining_test {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running neuron specific pp test : $@"
+  PJRT_DEVICE=NEURON torchrun --nproc_per_node 2 "$@"
+}
+
 function run_torch_op_tests {
   run_dynamic "$_TEST_DIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA
   run_test_without_functionalization "$_TEST_DIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA
@@ -276,6 +284,7 @@ function run_xla_op_tests3 {
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py"
   run_test "$_TEST_DIR/test_pallas.py"
   run_xla_ir_hlo_debug run_test "$_TEST_DIR/test_user_computation_debug_cache.py"
+  # run_neuron_pipelining_test "$_TEST_DIR/pipelining/test_basic_pipelining.py"
 
   # Test examples
   run_test "$_TEST_DIR/../examples/scan/scan_examples.py"
diff --git a/torch_xla/distributed/pipelining/__init__.py b/torch_xla/distributed/pipelining/__init__.py
new file mode 100644
index 00000000000..6deceba400b
--- /dev/null
+++ b/torch_xla/distributed/pipelining/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2024, PyTorch XLA Contributors
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from .xla_pipeline_stage_coordinator import register_xla_pipeline_stage_coordinator
+
+register_xla_pipeline_stage_coordinator()
diff --git a/torch_xla/distributed/pipelining/xla_pipeline_stage_coordinator.py b/torch_xla/distributed/pipelining/xla_pipeline_stage_coordinator.py
new file mode 100644
index 00000000000..48f968a8d93
--- /dev/null
+++ b/torch_xla/distributed/pipelining/xla_pipeline_stage_coordinator.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2024, PyTorch XLA Contributors
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+XLA Pipeline Stage Coordinator for PyTorch Pipeline Parallelism
+
+This module provides XLA-specific pipeline stage coordination for pipeline parallelism,
+inheriting from the base PipelineStageCoordinator and overriding only the methods that
+need XLA-specific behavior.
+"""
+
+import logging
+import torch
+import torch.distributed as dist
+from torch.distributed.pipelining.pipeline_stage_coordinator import PipelineStageCoordinator
+
+# Import XLA modules
+import torch_xla
+import torch_xla.runtime as xr
+
+logger = logging.getLogger(__name__)
+
+
+class XlaPipelineStageCoordinator(PipelineStageCoordinator):
+  """
+  XLA-specific pipeline stage coordinator for pipeline parallelism.
+
+  Inherits from the base PipelineStageCoordinator and overrides only the methods
+  that need XLA-specific behavior.
+  """
+
+  def __init__(self, device: torch.device, group: dist.ProcessGroup):
+    # check XLA device and group
+    assert device.type == "xla", f"XlaPipelineStageCoordinator: device {device} is not an XLA device"
+    if group:
+      assert isinstance(
+          group, dist.ProcessGroup
+      ), f"XlaPipelineStageCoordinator: group {group} is not a ProcessGroup"
+
+    # For XLA, stage metadata is exchanged on CPU over a "gloo" group,
+    # regardless of the input device and group params
+    device = torch.device("cpu")
+    backend = "gloo"
+    ranks = list(range(xr._WORLD_SIZE))
+    group = dist.new_group(backend=backend, ranks=ranks)
+    super().__init__(device, group)
+    logger.debug(
+        "XlaPipelineStageCoordinator: Initialized with device=%s and group=%s",
+        self._device, self._group.name())
+
+  def create_stage_communication_buffer(self, metadata: torch.Tensor,
+                                        device: torch.device) -> torch.Tensor:
+    """
+    Create a tensor buffer from metadata for pipeline stage communication.
+
+    XLA-specific implementation: create an empty tensor on the XLA device.
+
+    Args:
+      metadata: The metadata object received from another stage
+      device: Target XLA device
+
+    Returns:
+      Empty XLA tensor ready for pipeline stage communication
+    """
+    logger.debug(
+        "XlaPipelineStageCoordinator: Creating stage communication buffer from metadata - shape %s, dtype %s on device %s",
+        metadata.shape, metadata.dtype, device)
+
+    # For XLA, we need to ensure the device is set to "xla"
+    # regardless of the input device parameter
+    return torch.empty(metadata.shape, dtype=metadata.dtype, device="xla")
+
+
+def register_xla_pipeline_stage_coordinator():
+  """Register the XLA pipeline stage coordinator with PyTorch's registry."""
+  logger.debug("Attempting to register XLA pipeline stage coordinator")
+  try:
+    # Import the registration function from PyTorch
+    from torch.distributed.pipelining.pipeline_stage_coordinator import register_pipeline_stage_coordinator
+
+    # Register XLA coordinator
+    def create_xla_coordinator(device, group):
+      logger.debug("Creating XLA pipeline stage coordinator instance")
+      return XlaPipelineStageCoordinator(device, group)
+
+    register_pipeline_stage_coordinator(
+        torch.device("xla"), create_xla_coordinator)
+    logger.debug("Successfully registered XLA pipeline stage coordinator")
+
+  except ImportError as e:
+    logger.debug(
+        "Failed to register XLA pipeline stage coordinator due to import error: %s",
+        e)
+    # If PyTorch doesn't have the pipeline stage coordinator infrastructure yet,
+    # this will be a no-op. The coordinator can still be used directly.
+    pass
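
Note (illustration, not part of the diff above): the actual metadata send/recv lives in the upstream PipelineStageCoordinator base class, which this series only subclasses. The Python sketch below merely illustrates the intended split, assuming a CPU "gloo" group for the shape/dtype handshake and an XLA device for the actual activation buffers; send_shape_metadata and recv_shape_metadata are hypothetical names, not functions defined by this patch or by PyTorch.

import torch
import torch.distributed as dist


def send_shape_metadata(t: torch.Tensor, dst: int, gloo_group) -> None:
  # Hypothetical helper: describe the outgoing activation over the CPU/gloo
  # group so the peer can size its receive buffer (ndim followed by each dim).
  header = torch.tensor([t.dim(), *t.shape], dtype=torch.int64)
  dist.send(torch.tensor([header.numel()], dtype=torch.int64), dst, group=gloo_group)
  dist.send(header, dst, group=gloo_group)


def recv_shape_metadata(src: int, gloo_group) -> torch.Size:
  # Hypothetical helper: mirror image of send_shape_metadata().
  length = torch.zeros(1, dtype=torch.int64)
  dist.recv(length, src, group=gloo_group)
  header = torch.zeros(int(length.item()), dtype=torch.int64)
  dist.recv(header, src, group=gloo_group)
  ndim = int(header[0].item())
  return torch.Size(header[1:1 + ndim].tolist())


# The receiving stage then allocates the payload buffer on the XLA device,
# which is what XlaPipelineStageCoordinator.create_stage_communication_buffer()
# does with the metadata tensor it is handed, e.g.:
#   shape = recv_shape_metadata(src=prev_rank, gloo_group=coordinator_group)
#   buf = torch.empty(shape, dtype=torch.float32, device="xla")

Forcing the coordinator onto a CPU gloo group (see __init__ above) keeps these small control messages off the lazy XLA graph and lets them travel across hosts, while the activation payloads themselves are expected to stay on the XLA devices.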
From c2fa38e5d25412a1bfcf2758f22709cc43e9be3d Mon Sep 17 00:00:00 2001
From: kvshbg-aws
Date: Tue, 19 Aug 2025 23:02:28 +0000
Subject: [PATCH 2/3] fix: add a TODO for the pipelining test

---
 test/run_tests.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/run_tests.sh b/test/run_tests.sh
index 1691c3f5136..1e835e5f640 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -284,6 +284,7 @@ function run_xla_op_tests3 {
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py"
   run_test "$_TEST_DIR/test_pallas.py"
   run_xla_ir_hlo_debug run_test "$_TEST_DIR/test_user_computation_debug_cache.py"
+  # TODO: enable the below test once https://github.com/pytorch/xla/pull/9373 is merged
   # run_neuron_pipelining_test "$_TEST_DIR/pipelining/test_basic_pipelining.py"
 
   # Test examples
   run_test "$_TEST_DIR/../examples/scan/scan_examples.py"

From f6b04a1a058ed37f36951607d91fd3c2b3caf6f8 Mon Sep 17 00:00:00 2001
From: kvshbg-aws
Date: Thu, 21 Aug 2025 17:24:57 +0000
Subject: [PATCH 3/3] move neuron test to the appropriate folder

---
 test/{ => neuron}/pipelining/test_basic_pipelining.py |  0
 test/neuron/run_tests.sh                              |  2 ++
 test/run_tests.sh                                     | 10 ----------
 3 files changed, 2 insertions(+), 10 deletions(-)
 rename test/{ => neuron}/pipelining/test_basic_pipelining.py (100%)

diff --git a/test/pipelining/test_basic_pipelining.py b/test/neuron/pipelining/test_basic_pipelining.py
similarity index 100%
rename from test/pipelining/test_basic_pipelining.py
rename to test/neuron/pipelining/test_basic_pipelining.py
diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh
index d8ee9a39b03..537033253d2 100755
--- a/test/neuron/run_tests.sh
+++ b/test/neuron/run_tests.sh
@@ -285,6 +285,8 @@ function run_xla_op_tests3 {
 function run_xla_neuron_tests {
   run_test "$_TEST_DIR/neuron/test_neuron_utils.py"
   run_test "$_TEST_DIR/neuron/test_neuron_data_types.py"
+  # TODO: enable the below test once https://github.com/pytorch/xla/pull/9373 is merged
+  # run_torchrun "$_TEST_DIR/neuron/pipelining/test_basic_pipelining.py"
 }
 
 #######################################################################################
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 1e835e5f640..54c893c7b40 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -135,14 +135,6 @@ function run_pt_xla_debug_level2 {
   PT_XLA_DEBUG_LEVEL=2 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
 }
 
-function run_neuron_pipelining_test {
-  if ! test_is_selected "$1"; then
-    return
-  fi
-  echo "Running neuron specific pp test : $@"
-  PJRT_DEVICE=NEURON torchrun --nproc_per_node 2 "$@"
-}
-
 function run_torch_op_tests {
   run_dynamic "$_TEST_DIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA
   run_test_without_functionalization "$_TEST_DIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA
@@ -284,8 +276,6 @@ function run_xla_op_tests3 {
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$_TEST_DIR/test_core_aten_ops.py"
   run_test "$_TEST_DIR/test_pallas.py"
   run_xla_ir_hlo_debug run_test "$_TEST_DIR/test_user_computation_debug_cache.py"
-  # TODO: enable the below test once https://github.com/pytorch/xla/pull/9373 is merged
-  # run_neuron_pipelining_test "$_TEST_DIR/pipelining/test_basic_pipelining.py"
 
   # Test examples
   run_test "$_TEST_DIR/../examples/scan/scan_examples.py"
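
Note (illustration): for readers skimming the series, the sketch below condenses the user-facing flow of test_basic_pipelining.py from PATCH 1 into a standalone script. The script name pipeline_smoke.py and the TwoStageModel module are made up for this sketch, and it assumes the prerequisite coordinator support referenced in the TODOs above has landed. It would be launched the way the patches launch the test, i.e. PJRT_DEVICE=NEURON torchrun --nproc_per_node 2 pipeline_smoke.py.

# pipeline_smoke.py -- condensed from test_basic_pipelining.py above.
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.pipelining import (PipelineStage, ScheduleGPipe,
                                          SplitPoint, pipeline)

import torch_xla
# Side effect of this import: registers XlaPipelineStageCoordinator.
import torch_xla.distributed.pipelining  # noqa: F401


class TwoStageModel(nn.Module):
  """Two blocks under `layers` so a 2-rank pipeline can split at layers.1."""

  def __init__(self, d=1024, num_classes=32):
    super().__init__()
    self.layers = nn.Sequential(
        nn.Sequential(nn.Linear(d, d), nn.ReLU()),
        nn.Sequential(nn.Linear(d, num_classes)))

  def forward(self, x):
    return self.layers(x)


if __name__ == "__main__":
  rank = int(os.environ["RANK"])
  world_size = int(os.environ["WORLD_SIZE"])
  device = torch_xla.device()
  dist.init_process_group(
      backend="xla", init_method="xla://", rank=rank, world_size=world_size)

  # Trace and split the model on the meta device, exactly as the test does.
  model = TwoStageModel().to("meta")
  example = torch.randn(2, 1024, device="meta")
  pipe = pipeline(
      model, mb_args=(example,), split_spec={"layers.1": SplitPoint.BEGINNING})

  # Materialize only this rank's stage on the XLA device and wrap it in GPipe.
  stage_mod = pipe.get_stage_module(rank).to_empty(device=device)
  stage = PipelineStage(stage_mod, rank, pipe.num_stages, device)
  schedule = ScheduleGPipe(stage, 2, loss_fn=nn.CrossEntropyLoss())

  # One schedule step: rank 0 feeds data, the last rank supplies the target.
  data = torch.randn(2, 1024, device=device)
  target = torch.randint(0, 32, (2,), device=device)
  if rank == 0:
    schedule.step(data)
  else:
    schedule.step(target=target, losses=[])
  torch_xla.sync()
  dist.destroy_process_group()

The only XLA-specific line is the import of torch_xla.distributed.pipelining, whose __init__.py registers XlaPipelineStageCoordinator as a side effect; everything else is the stock torch.distributed.pipelining API that the test exercises.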