59 changes: 38 additions & 21 deletions tests/pytorch/test_parallel_cross_entropy.py
@@ -6,6 +6,8 @@
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

from utils import dtype_tols


class TestParallelCrossEntropy:

@@ -18,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float):
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)

def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):

def generate_input(
self,
dtype: torch.dtype,
swap_dim: bool,
ignore_idx: bool,
device: torch.device = "cuda",
):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5)

# Generate random data
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
else:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)

if ignore_idx:
for i in ignore:
@@ -40,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
else:
self.tar_test[0][i] = -100

# Make copy of data for reference implementation
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))

# Enable autograd
self.input_test.requires_grad_()
self.input_ref.requires_grad_()

def one_iteration_test(
self,
dtype: torch.dtype,
@@ -52,35 +65,39 @@ def one_iteration_test(
ignore_idx: bool = False,
):

# Random data
self.generate_input(dtype, swap_dim, ignore_idx)

# Forward pass
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
)

ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)

# Compute square to avoid trivial backward pass
test_loss = torch.square(test_loss)
ref_loss = torch.square(ref_loss)

# Backward pass
if reduce_loss:
test_loss.backward()
ref_loss.backward()
else:
test_loss.sum().backward()
ref_loss.sum().backward()

test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss

if ignore_idx:
print(test_loss, ref_loss)

# Check that loss and grad input match
tols = dtype_tols(dtype)
test_loss = test_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.reshape(test_loss.size())
test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
torch.testing.assert_close(test_loss, ref_loss, **tols)
torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)

# Reset data
self.input_test = None
self.input_ref = None
self.tar_test = None
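For context, the tolerance-based comparison above leans on utils.dtype_tols, which maps an input dtype to rtol/atol keyword arguments for torch.testing.assert_close. A minimal sketch of such a helper is below, with representative tolerance values assumed for illustration; the actual helper in the test utilities may choose different numbers.

import torch

def dtype_tols(dtype: torch.dtype) -> dict:
    """Map a dtype to rtol/atol kwargs for torch.testing.assert_close (sketch)."""
    if dtype == torch.float64:
        return {"rtol": 1e-7, "atol": 1e-7}
    if dtype == torch.float32:
        return {"rtol": 1e-5, "atol": 1e-6}
    if dtype == torch.float16:
        return {"rtol": 1e-3, "atol": 1e-3}
    if dtype == torch.bfloat16:
        return {"rtol": 1.6e-2, "atol": 1e-2}
    raise ValueError(f"Unsupported dtype: {dtype}")

Unpacking the returned dict as **tols keeps the pass/fail thresholds tied to the precision of the dtype under test, rather than hard-coding one tolerance for all cases.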
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/triton/cross_entropy.py
@@ -230,6 +230,7 @@ def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
grad_output_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
@@ -252,6 +253,7 @@ def element_mul_kernel(
X_ptr += program_id * X_stride

# Load the gradient output value
grad_output_ptr += program_id * grad_output_stride
grad_output = tl.load(grad_output_ptr)

# Perform the element-wise multiplication
@@ -360,6 +362,7 @@ def cross_entropy_backward(
_input,
_input.stride(-2),
grad_output,
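    # Stride 0 re-reads one scalar grad_output (reduced loss); stride 1 steps through per-token grads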
1 if grad_output.numel() > 1 else 0,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
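The new grad_output_stride argument lets element_mul_kernel serve both reductions: with reduction="mean" the incoming grad_output is a single scalar, so the launch passes stride 0 and every program re-reads the same value, while an unreduced loss supplies one gradient per token and stride 1 walks across them. The sketch below illustrates the same broadcast-by-stride idea in plain PyTorch; scale_rows is a hypothetical helper written for this note, not the Triton kernel itself.

import torch

def scale_rows(x: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    """Multiply each row of x by a gradient value, broadcasting via stride."""
    # Mirrors the launch-time argument: 1 if grad_output.numel() > 1 else 0
    stride = 1 if grad_output.numel() > 1 else 0
    flat = grad_output.reshape(-1)
    out = torch.empty_like(x)
    for row in range(x.shape[0]):  # each Triton program handles one row of X
        g = flat[row * stride]  # stride 0: every row reads the same scalar
        out[row] = x[row] * g
    return out

Either way the kernel body is unchanged; only the pointer offset grad_output_ptr += program_id * grad_output_stride decides whether the gradient is broadcast or indexed per token.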