Adding functionality for metadata communication across hosts #9570
Open
kvshbg-aws wants to merge 4 commits into pytorch:master from kvshbg-aws:pp_xla
Changes from all commits (4 commits)
@@ -0,0 +1,197 @@
# Test for XLA Pipeline Parallelism - requires torchrun --nproc_per_node=2
# Only runs on NEURON devices with exactly 2 processes

import os
import sys
import unittest
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torch.distributed as dist

import torch_xla.distributed.pipelining
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla
import torch_xla.utils.utils as xu


class MLPModule(torch.nn.Module):

  def __init__(self, d_hid, num_classes, is_last=False):
    super().__init__()
    self.net1 = torch.nn.Linear(d_hid, d_hid // 2)
    self.relu = torch.nn.ReLU()
    self.net2 = torch.nn.Linear(d_hid // 2, num_classes if is_last else d_hid)

  def forward(self, x):
    x = self.net1(x)
    x = self.relu(x)
    x = self.net2(x)
    return x


class SimpleTransf(torch.nn.Module):

  def __init__(self, hidden_dim, num_classes, num_layers):
    super().__init__()
    self.layers = torch.nn.Sequential(*[
        MLPModule(hidden_dim, num_classes, is_last=(i == num_layers - 1))
        for i in range(num_layers)
    ])

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.layers(x)


class TestLazyBasicPipelining(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    # Only run on NEURON devices
    if xr.device_type() != 'NEURON':
      raise unittest.SkipTest('Test only runs on NEURON devices')

    # Require distributed environment with exactly 2 devices
    if 'RANK' not in os.environ or 'WORLD_SIZE' not in os.environ:
      raise unittest.SkipTest('Test requires torchrun with RANK and WORLD_SIZE')

    world_size = int(os.environ.get('WORLD_SIZE', 0))
    if world_size != 2:
      raise unittest.SkipTest('Test requires exactly 2 devices')

  def test_pipeline_training(self):
    """Test distributed pipeline training with GPipe scheduling and loss convergence"""
    # configs
    hidden_dim = 1024
    num_classes = 32
    num_layers = 2
    batch_size = 2
    num_epochs = 2
    lr = 0.01
    train_dataset_len = 1024 * 8
    gradient_accumulation_steps = 1
    logging_steps = 1

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    chunks = world_size

    device = torch_xla.device()
    print(f"Rank {rank} using device {device}")

    # Initialize process group
    dist.init_process_group(
        backend="xla", rank=rank, world_size=world_size, init_method='xla://')

    torch.manual_seed(42)
    model = SimpleTransf(hidden_dim, num_classes, num_layers).to("meta")

    # Define split points for pipeline parallelism
    split_spec = {}
    for i in range(1, chunks):
      layer_idx = (i * num_layers) // chunks
      split_spec[f"layers.{layer_idx}"] = SplitPoint.BEGINNING

    # Create a sample input for the pipeline
    example_input = torch.randn(batch_size, hidden_dim, device="meta")

    # Create the pipeline and respective stage for the rank
    pipe = pipeline(model, mb_args=(example_input,), split_spec=split_spec)

    # Assertions for pipeline creation
    self.assertEqual(
        pipe.num_stages, chunks,
        f"Pipeline stages should match chunks: {pipe.num_stages} != {chunks}")

    loss_fn = nn.CrossEntropyLoss()
    stage_mod = pipe.get_stage_module(rank)
    stage_mod = stage_mod.to_empty(device=device)
    stage = PipelineStage(stage_mod, rank, chunks, device)
    schedule = ScheduleGPipe(stage, batch_size, loss_fn=loss_fn)

    del model
    if rank == 0:
      print(f"{rank=}, {schedule._get_pipeline_order()}\n", flush=True)

    losses = []
    to_check_losses = []
    optimizer = optim.SGD(stage_mod.parameters(), lr=lr)

    train_loader = xu.SampleGenerator(
        data=(
            torch.randn(batch_size, hidden_dim),
            torch.randint(0, num_classes, (batch_size,), dtype=torch.int64),
        ),
        sample_count=train_dataset_len //
        (batch_size * gradient_accumulation_steps),
    )

    def print_epoch(epoch, step, losses):
      to_check_losses.append(losses[-1].item())
      print(f"Epoch {epoch} step {step} loss {losses[-1]}")

    print(f"{world_size=}, {rank=}")

    # Training loop with rank-specific pipeline logic
    for epoch in range(num_epochs):
      for step, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        if rank == 0:  # First rank handles input
          data = data.to(device)
          _ = schedule.step(data)
        elif rank == world_size - 1:  # Last rank handles target and loss
          target = target.to(device)
          _ = schedule.step(target=target, losses=losses)
        else:  # Middle ranks just forward/backward
          _ = schedule.step()

        if rank == world_size - 1:
          if step % logging_steps == 0:
            xm.add_step_closure(print_epoch, (epoch, step, losses))

        optimizer.step()
        torch.distributed.barrier()
        torch_xla.sync()
        if step == 100:  # break for assertions
          break
      break
    # Validate training results
    if to_check_losses:
      if rank == 0:
        self.assertTrue(
            len(to_check_losses) == 0,
            f"rank0 should not store losses, but it has: {to_check_losses}")
      if rank == world_size - 1:
        # Last rank records losses - verify they exist and are finite
        self.assertGreater(
            len(to_check_losses), 0, "Last rank should record losses")
        for loss in to_check_losses:
          self.assertTrue(
              torch.isfinite(torch.Tensor([loss])),
              f"Loss should be finite: {loss}")

        # Check loss convergence (early vs late training)
        if len(to_check_losses) >= 10:
          early_losses = [l for l in to_check_losses[:len(to_check_losses) // 3]]
          late_losses = [l for l in to_check_losses[-len(to_check_losses) // 3:]]
          early_avg = sum(early_losses) / len(early_losses)
          late_avg = sum(late_losses) / len(late_losses)

          print(
              f"Early avg loss: {early_avg:.4f}, Late avg loss: {late_avg:.4f}")
          self.assertLessEqual(
              late_avg, early_avg * 1.1,
              f"Loss should generally decrease: early={early_avg:.4f}, late={late_avg:.4f}"
          )
    dist.destroy_process_group()


if __name__ == '__main__':
  unittest.main()
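For context, the header comment says this test must be launched with torchrun on a NEURON host with exactly 2 processes. A hedged example invocation (the file path below is a placeholder, since the diff view does not show where the test lives):

torchrun --nproc_per_node=2 <path/to/this_test_file.py>

With WORLD_SIZE = 2 and num_layers = 2, the split-point loop above yields split_spec = {"layers.1": SplitPoint.BEGINNING}, so rank 0 owns layers.0 and rank 1 owns layers.1.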
@@ -0,0 +1,8 @@
# Copyright (c) 2024, PyTorch XLA Contributors
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .xla_pipeline_stage_coordinator import register_xla_pipeline_stage_coordinator

register_xla_pipeline_stage_coordinator()
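A note on how this registration is exercised (inferred from the test above): because the call sits at package import time, simply importing the torch_xla pipelining package registers the XLA coordinator as a side effect.

# A minimal sketch, assuming a torch_xla build that ships this package:
# the import alone runs register_xla_pipeline_stage_coordinator(), exactly
# as the test file above does before building its pipeline.
import torch_xla.distributed.pipelining  # noqa: F401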
torch_xla/distributed/pipelining/xla_pipeline_stage_coordinator.py (98 additions, 0 deletions)
@@ -0,0 +1,98 @@
# Copyright (c) 2024, PyTorch XLA Contributors
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
XLA Pipeline Stage Coordinator for PyTorch Pipeline Parallelism

This module provides XLA-specific pipeline stage coordination for pipeline parallelism,
inheriting from the base PipelineStageCoordinator and overriding only the methods that
need XLA-specific behavior.
"""

import logging
import torch
import torch.distributed as dist
from torch.distributed.pipelining.pipeline_stage_coordinator import PipelineStageCoordinator

# Import XLA modules
import torch_xla
import torch_xla.runtime as xr

logger = logging.getLogger(__name__)


class XlaPipelineStageCoordinator(PipelineStageCoordinator):
  """
  XLA-specific pipeline stage coordinator for pipeline parallelism.

  Inherits from the base PipelineStageCoordinator and overrides only the methods
  that need XLA-specific behavior.
  """

  def __init__(self, device: torch.device, group: dist.ProcessGroup):
    # check XLA device and group
    assert device.type == "xla", f"XlaPipelineStageCoordinator: device {device} is not XLA device"
    if group:
      assert isinstance(
          group, dist.ProcessGroup
      ), f"XlaPipelineStageCoordinator: group {group} is not a ProcessGroup"

    # For XLA, we need to ensure the device is set to "cpu" and group is "gloo"
    # regardless of the input device and group params
    device = torch.device("cpu")
    backend = "gloo"
    ranks = list(range(xr._WORLD_SIZE))
    group = dist.new_group(backend=backend, ranks=ranks)
    super().__init__(device, group)
    logger.debug(
        "XLAPipelineStageCoordinator: Initialized with device=%s and group=%s",
        self._device, self._group.name())

  def create_stage_communication_buffer(self, metadata: torch.Tensor,
                                        device: torch.device) -> torch.Tensor:
    """
    Create a tensor buffer from metadata for pipeline stage communication.

    XLA-specific implementation: create empty tensor on XLA device.

    Args:
      metadata: The metadata object received from another stage
      device: Target XLA device

    Returns:
      Empty XLA tensor ready for pipeline stage communication
    """
    logger.debug(
        "XlaPipelineStageCoordinator: Creating stage communication buffer from metadata - shape %s, dtype %s on device %s",
        metadata.shape, metadata.dtype, device)

    # For XLA, we need to ensure the device is set to "xla"
    # regardless of the input device parameter
    return torch.empty(metadata.shape, dtype=metadata.dtype, device="xla")


def register_xla_pipeline_stage_coordinator():
  """Register the XLA pipeline stage coordinator with PyTorch's registry."""
  logger.debug("Attempting to register XLA pipeline stage coordinator")
  try:
    # Import the registration function from PyTorch
    from torch.distributed.pipelining.pipeline_stage_coordinator import register_pipeline_stage_coordinator

    # Register XLA coordinator
    def create_xla_coordinator(device, group):
      logger.debug("Creating XLA pipeline stage coordinator instance")
      return XlaPipelineStageCoordinator(device, group)

    register_pipeline_stage_coordinator(
        torch.device("xla"), create_xla_coordinator)

    logger.debug("Successfully registered XLA pipeline stage coordinator")

  except ImportError as e:
    logger.debug(
        "Failed to register XLA pipeline stage coordinator due to import error: %s",
        e)
    # If PyTorch doesn't have the pipeline stage coordinator infrastructure yet,
    # this will be a no-op. The coordinator can still be used directly.
    pass
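For illustration, a hedged sketch of exercising the coordinator's buffer-allocation path directly. It assumes the companion PyTorch change that adds torch.distributed.pipelining.pipeline_stage_coordinator is present, an XLA runtime is available, and torchrun-style environment variables are set; the variable names are illustrative only and not part of this PR.

import torch
import torch.distributed as dist
import torch_xla
from torch_xla.distributed.pipelining.xla_pipeline_stage_coordinator import XlaPipelineStageCoordinator

# Assumes RANK/WORLD_SIZE are set (e.g. via torchrun) and an XLA device exists.
dist.init_process_group(backend="xla", init_method="xla://")

# Passing None for the group is accepted by the constructor's check above.
coordinator = XlaPipelineStageCoordinator(torch.device("xla"), None)

# "Metadata" is a tensor that only carries shape and dtype (e.g. on the meta
# device); the coordinator turns it into an empty receive buffer on the XLA
# device for stage-to-stage communication.
metadata = torch.empty(2, 1024, dtype=torch.float32, device="meta")
recv_buffer = coordinator.create_stage_communication_buffer(
    metadata, torch_xla.device())
assert recv_buffer.shape == metadata.shape
assert recv_buffer.dtype == metadata.dtype

Note the design choice visible in __init__: whatever XLA device and group are handed in, the coordinator passes the base class a CPU device and a fresh gloo group, so the cross-host metadata traffic is expected to travel over gloo while the buffers it creates live on the XLA device.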