197 changes: 197 additions & 0 deletions test/neuron/pipelining/test_basic_pipelining.py
@@ -0,0 +1,197 @@
# Test for XLA Pipeline Parallelism - requires torchrun --nproc_per_node=2
# Only runs on NEURON devices with exactly 2 processes
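#
# Example invocation (illustrative; exact environment setup may differ):
#   PJRT_DEVICE=NEURON torchrun --nproc_per_node=2 \
#       test/neuron/pipelining/test_basic_pipelining.py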

import os
import sys
import unittest
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torch.distributed as dist

import torch_xla.distributed.pipelining  # side-effect import: registers the XLA pipeline stage coordinator
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla
import torch_xla.utils.utils as xu


class MLPModule(torch.nn.Module):

def __init__(self, d_hid, num_classes, is_last=False):
super().__init__()
self.net1 = torch.nn.Linear(d_hid, d_hid // 2)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(d_hid // 2, num_classes if is_last else d_hid)

def forward(self, x):
x = self.net1(x)
x = self.relu(x)
x = self.net2(x)
return x


class SimpleTransf(torch.nn.Module):

def __init__(self, hidden_dim, num_classes, num_layers):
super().__init__()
self.layers = torch.nn.Sequential(*[
MLPModule(hidden_dim, num_classes, is_last=(i == num_layers - 1))
for i in range(num_layers)
])

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x)


class TestLazyBasicPipelining(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Only run on NEURON devices
if xr.device_type() != 'NEURON':
raise unittest.SkipTest('Test only runs on NEURON devices')

# Require distributed environment with exactly 2 devices
if 'RANK' not in os.environ or 'WORLD_SIZE' not in os.environ:
raise unittest.SkipTest('Test requires torchrun with RANK and WORLD_SIZE')

world_size = int(os.environ.get('WORLD_SIZE', 0))
if world_size != 2:
raise unittest.SkipTest('Test requires exactly 2 devices')

def test_pipeline_training(self):
"""Test distributed pipeline training with GPipe scheduling and loss convergence"""
# configs
hidden_dim = 1024
num_classes = 32
num_layers = 2
batch_size = 2
num_epochs = 2
lr = 0.01
train_dataset_len = 1024 * 8
gradient_accumulation_steps = 1
logging_steps = 1

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
chunks = world_size

device = torch_xla.device()
print(f"Rank {rank} using device {device}")

# Initialize process group
dist.init_process_group(
backend="xla", rank=rank, world_size=world_size, init_method='xla://')
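    # init_method='xla://' lets the XLA backend derive rendezvous information
    # from the runtime/torchrun environment.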

torch.manual_seed(42)
model = SimpleTransf(hidden_dim, num_classes, num_layers).to("meta")

# Define split points for pipeline parallelism
split_spec = {}
for i in range(1, chunks):
layer_idx = (i * num_layers) // chunks
split_spec[f"layers.{layer_idx}"] = SplitPoint.BEGINNING
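    # With num_layers=2 and chunks=2 this resolves to
    # {"layers.1": SplitPoint.BEGINNING}, i.e. stage 0 owns layers.0 and
    # stage 1 owns layers.1.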

# Create a sample input for the pipeline
example_input = torch.randn(batch_size, hidden_dim, device="meta")
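    # Keeping the model and example input on the "meta" device lets pipeline()
    # split the model without allocating real memory; each rank materializes
    # only its own stage later via to_empty().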

# Create the pipeline and respective stage for the rank
pipe = pipeline(model, mb_args=(example_input,), split_spec=split_spec)

# Assertions for pipeline creation
self.assertEqual(
pipe.num_stages, chunks,
f"Pipeline stages should match chunks: {pipe.num_stages} != {chunks}")

loss_fn = nn.CrossEntropyLoss()
stage_mod = pipe.get_stage_module(rank)
stage_mod = stage_mod.to_empty(device=device)
stage = PipelineStage(stage_mod, rank, chunks, device)
schedule = ScheduleGPipe(stage, batch_size, loss_fn=loss_fn)
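    # The second positional argument to ScheduleGPipe is the number of
    # microbatches per step; here it is set to batch_size (== world_size == 2).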

del model
if rank == 0:
print(f"{rank=}, {schedule._get_pipeline_order()}\n", flush=True)

losses = []
to_check_losses = []
optimizer = optim.SGD(stage_mod.parameters(), lr=lr)

train_loader = xu.SampleGenerator(
data=(
torch.randn(batch_size, hidden_dim),
torch.randint(0, num_classes, (batch_size,), dtype=torch.int64),
),
sample_count=train_dataset_len //
(batch_size * gradient_accumulation_steps),
)

def print_epoch(epoch, step, losses):
to_check_losses.append(losses[-1].item())
print(f"Epoch {epoch} step {step} loss {losses[-1]}")

print(f"{world_size=}, {rank=}")

# Training loop with rank-specific pipeline logic
for epoch in range(num_epochs):
for step, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
if rank == 0: # First rank handles input
data = data.to(device)
_ = schedule.step(data)
elif rank == world_size - 1: # Last rank handles target and loss
target = target.to(device)
_ = schedule.step(target=target, losses=losses)
else: # Middle ranks just forward/backward
_ = schedule.step()

if rank == world_size - 1:
if step % logging_steps == 0:
xm.add_step_closure(print_epoch, (epoch, step, losses))

optimizer.step()
torch.distributed.barrier()
torch_xla.sync()
if step == 100: # break for assertions
break
break
# Validate training results
    # rank 0 never registers the loss-printing closure, so it must not have
    # recorded any losses
    if rank == 0:
      self.assertEqual(
          len(to_check_losses), 0,
          f"rank0 should not store losses, but it has: {to_check_losses}")
    if rank == world_size - 1:
      # Last rank records losses - verify they exist and are finite
      self.assertGreater(
          len(to_check_losses), 0, "Last rank should record losses")
      for loss in to_check_losses:
        self.assertTrue(
            torch.isfinite(torch.tensor(loss)),
            f"Loss should be finite: {loss}")

      # Check loss convergence (early vs late training)
      if len(to_check_losses) >= 10:
        early_losses = to_check_losses[:len(to_check_losses) // 3]
        late_losses = to_check_losses[-len(to_check_losses) // 3:]
        early_avg = sum(early_losses) / len(early_losses)
        late_avg = sum(late_losses) / len(late_losses)

        print(
            f"Early avg loss: {early_avg:.4f}, Late avg loss: {late_avg:.4f}")
        self.assertLessEqual(
            late_avg, early_avg * 1.1,
            f"Loss should generally decrease: early={early_avg:.4f}, late={late_avg:.4f}"
        )
    dist.destroy_process_group()


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions test/neuron/run_tests.sh
@@ -286,6 +286,8 @@ function run_xla_op_tests3 {
function run_xla_neuron_tests {
run_test "$_TEST_DIR/neuron/test_neuron_utils.py"
run_test "$_TEST_DIR/neuron/test_neuron_data_types.py"
# TODO: enable the test below once https://github.com/pytorch/xla/pull/9373 is merged
# run_torchrun "$_TEST_DIR/neuron/pipelining/test_basic_pipelining.py"
}

#######################################################################################
8 changes: 8 additions & 0 deletions torch_xla/distributed/pipelining/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) 2024, PyTorch XLA Contributors
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .xla_pipeline_stage_coordinator import register_xla_pipeline_stage_coordinator
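# Importing this package registers the XLA pipeline stage coordinator with
# PyTorch's pipelining machinery via the call below.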

register_xla_pipeline_stage_coordinator()
98 changes: 98 additions & 0 deletions torch_xla/distributed/pipelining/xla_pipeline_stage_coordinator.py
@@ -0,0 +1,98 @@
# Copyright (c) 2024, PyTorch XLA Contributors
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
XLA Pipeline Stage Coordinator for PyTorch Pipeline Parallelism

This module provides XLA-specific pipeline stage coordination for pipeline parallelism,
inheriting from the base PipelineStageCoordinator and overriding only the methods that
need XLA-specific behavior.
"""

import logging
import torch
import torch.distributed as dist
from torch.distributed.pipelining.pipeline_stage_coordinator import PipelineStageCoordinator

# Import XLA modules
import torch_xla
import torch_xla.runtime as xr

logger = logging.getLogger(__name__)


class XlaPipelineStageCoordinator(PipelineStageCoordinator):
"""
XLA-specific pipeline stage coordinator for pipeline parallelism.

Inherits from the base PipelineStageCoordinator and overrides only the methods
that need XLA-specific behavior.
"""

def __init__(self, device: torch.device, group: dist.ProcessGroup):
# check XLA device and group
assert device.type == "xla", f"XlaPipelineStageCoordinator: device {device} is not XLA device"
if group:
assert isinstance(
group, dist.ProcessGroup
), f"XlaPipelineStageCoordinator: group {group} is not a ProcessGroup"

    # For XLA, the coordinator must use a "cpu" device and a "gloo" process
    # group, regardless of the device and group that were passed in.
device = torch.device("cpu")
backend = "gloo"
ranks = list(range(xr._WORLD_SIZE))
group = dist.new_group(backend=backend, ranks=ranks)
super().__init__(device, group)
logger.debug(
"XLAPipelineStageCoordinator: Initialized with device=%s and group=%s",
self._device, self._group.name())

def create_stage_communication_buffer(self, metadata: torch.Tensor,
device: torch.device) -> torch.Tensor:
"""
Create a tensor buffer from metadata for pipeline stage communication.

XLA-specific implementation: create empty tensor on XLA device.

Args:
metadata: The metadata object received from another stage
device: Target XLA device

Returns:
Empty XLA tensor ready for pipeline stage communication
"""
logger.debug(
"XlaPipelineStageCoordinator: Creating stage communication buffer from metadata - shape %s, dtype %s on device %s",
metadata.shape, metadata.dtype, device)

# For XLA, we need to ensure the device is set to "xla"
# regardless of the input device parameter
return torch.empty(metadata.shape, dtype=metadata.dtype, device="xla")


def register_xla_pipeline_stage_coordinator():
"""Register the XLA pipeline stage coordinator with PyTorch's registry."""
logger.debug("Attempting to register XLA pipeline stage coordinator")
try:
# Import the registration function from PyTorch
from torch.distributed.pipelining.pipeline_stage_coordinator import register_pipeline_stage_coordinator

# Register XLA coordinator
def create_xla_coordinator(device, group):
logger.debug("Creating XLA pipeline stage coordinator instance")
return XlaPipelineStageCoordinator(device, group)

register_pipeline_stage_coordinator(
torch.device("xla"), create_xla_coordinator)
logger.debug("Successfully registered XLA pipeline stage coordinator")

except ImportError as e:
logger.debug(
"Failed to register XLA pipeline stage coordinator due to import error: %s",
e)
# If PyTorch doesn't have the pipeline stage coordinator infrastructure yet,
# this will be a no-op. The coordinator can still be used directly.
pass