diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py
index 77bea2b360..fa56852ffc 100644
--- a/tests/pytorch/test_parallel_cross_entropy.py
+++ b/tests/pytorch/test_parallel_cross_entropy.py
@@ -6,6 +6,8 @@
 import torch
 
 from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
 
+from utils import dtype_tols
+
 
 class TestParallelCrossEntropy:
@@ -18,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float):
             label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
         )
 
-    def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
-
+    def generate_input(
+        self,
+        dtype: torch.dtype,
+        swap_dim: bool,
+        ignore_idx: bool,
+        device: torch.device = "cuda",
+    ):
         SQ = random.choice([64, 128])
         batch = random.choice([1, 2])
         vocab = random.choice([64000, 128000])
         ignore = random.sample(range(0, SQ - 1), 5)
 
+        # Generate random data
         if swap_dim:
-            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
+            self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
         else:
-            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
-            self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
+            self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
+            self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)
 
         if ignore_idx:
             for i in ignore:
@@ -40,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
             else:
                 self.tar_test[0][i] = -100
 
+        # Make copy of data for reference implementation
         self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
         self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
 
+        # Enable autograd
+        self.input_test.requires_grad_()
+        self.input_ref.requires_grad_()
+
     def one_iteration_test(
         self,
         dtype: torch.dtype,
@@ -52,18 +65,20 @@ def one_iteration_test(
         ignore_idx: bool = False,
     ):
 
+        # Random data
         self.generate_input(dtype, swap_dim, ignore_idx)
 
-        self.input_test.requires_grad_(True)
-        self.input_ref.requires_grad_(True)
-
+        # Forward pass
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )
-
         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
 
-        # Handle backward pass based on the test scenario
+        # Compute square to avoid trivial backward pass
+        test_loss = torch.square(test_loss)
+        ref_loss = torch.square(ref_loss)
+
+        # Backward pass
         if reduce_loss:
             test_loss.backward()
             ref_loss.backward()
@@ -71,16 +86,18 @@ def one_iteration_test(
             test_loss.sum().backward()
             ref_loss.sum().backward()
 
-        test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
-
-        if ignore_idx:
-            print(test_loss, ref_loss)
-
-        # Compare gradients when backward pass was called
-        torch.testing.assert_close(
-            torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-        )
-
+        # Check that loss and grad input match
+        tols = dtype_tols(dtype)
+        test_loss = test_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.reshape(test_loss.size())
+        test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
+        ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
+        torch.testing.assert_close(test_loss, ref_loss, **tols)
+        torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)
+
+        # Reset data
         self.input_test = None
         self.input_ref = None
         self.tar_test = None
diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py
index 323a939223..7cfff1da9d 100644
--- a/transformer_engine/pytorch/triton/cross_entropy.py
+++ b/transformer_engine/pytorch/triton/cross_entropy.py
@@ -230,6 +230,7 @@ def element_mul_kernel(
     X_ptr,
     X_stride,
     grad_output_ptr,
+    grad_output_stride,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -252,6 +253,7 @@ def element_mul_kernel(
     X_ptr += program_id * X_stride
 
     # Load the gradient output value
+    grad_output_ptr += program_id * grad_output_stride
     grad_output = tl.load(grad_output_ptr)
 
     # Perform the element-wise multiplication
@@ -360,6 +362,7 @@ def cross_entropy_backward(
         _input,
         _input.stride(-2),
         grad_output,
+        1 if grad_output.numel() > 1 else 0,
         V,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32,
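
Note on the Triton change: the new grad_output_stride argument lets element_mul_kernel serve both reductions with one code path. With reduction="mean", grad_output is a single scalar shared by every row (stride 0); with reduction="none", each row reads its own upstream gradient (stride 1), which is what the unreduced-loss test above exercises. The following is a minimal pure-PyTorch sketch of that indexing rule, not part of the PR; the helper name scale_rows and the explicit row loop (standing in for one Triton program instance per row) are illustrative assumptions.

import torch

def scale_rows(grad_input: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # Same rule as the updated cross_entropy_backward call site:
    # scalar grad_output -> stride 0, per-row grad_output -> stride 1.
    stride = 1 if grad_output.numel() > 1 else 0
    flat = grad_output.flatten()
    out = grad_input.clone()
    for row in range(grad_input.shape[0]):  # one "program instance" per row
        out[row] *= flat[row * stride]
    return out

x = torch.randn(4, 8)
# Reduced loss: one scalar gradient scales every row.
assert torch.allclose(scale_rows(x, torch.tensor(2.0)), x * 2.0)
# Unreduced loss: each row is scaled by its own gradient element.
g = torch.arange(1.0, 5.0)
assert torch.allclose(scale_rows(x, g), x * g.unsqueeze(1))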