From 106c4cc7f1c3ce2cebcac94ed4659605f0c9fb18 Mon Sep 17 00:00:00 2001 From: Casper Date: Fri, 29 Aug 2025 20:51:20 +0200 Subject: [PATCH 01/10] fix cross entropy Signed-off-by: Casper --- .../pytorch/triton/cross_entropy.py | 116 +++++------------- 1 file changed, 33 insertions(+), 83 deletions(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 323a939223..dee9c98b36 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -96,32 +96,12 @@ def cross_entropy_kernel( world_size, ignore_idx, n_cols, - n_non_ignore, + n_rows, # NEW: total rows (B * SQ), used for all-gather offset + n_non_ignore, # number of non-ignored tokens, used for scaling reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - """ - This kernel computes both cross entropy loss and the gradient of the input. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - loss_ptr: Pointer to tensor to store the loss. - loss_stride (int): The stride of the loss tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride: The stride of m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - world_size (int): The size of world involved in this distributed loss calculation. - ignore_idx (int): Tokens to be ignored for loss and gradient calculation. - n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. - label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. - BLOCK_SIZE (int): The block size for Triton operations. - """ - program_id = tl.program_id(0).to(tl.int64) # locate the start index @@ -136,18 +116,20 @@ def cross_entropy_kernel( for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + # leave loss at 0 (loss buffer is initialized to 0 in Python) return loss_ptr += program_id * loss_stride m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride - # Need to reduce the m/d/X_y values from other TP ranks + # Reduce m/d/X_y across TP ranks m = tl.load(m_d_X_y_ptr) d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) for i in range(1, world_size): - offset = i * 3 * n_non_ignore * m_d_X_y_stride + # IMPORTANT: hop by n_rows (NOT n_non_ignore) + offset = i * 3 * n_rows * m_d_X_y_stride access_ptr = m_d_X_y_ptr + offset m_new = tl.load(access_ptr) d_new = tl.load(access_ptr + m_d_X_y_stride) @@ -157,63 +139,43 @@ def cross_entropy_kernel( m = tl.maximum(m, m_new) ori_X_y = tl.maximum(ori_X_y, X_y_new) - # Label smoothing is a general case of normal cross entropy + # Label smoothing setup scaled_x_sum = 0.0 eps = label_smoothing / (n_cols * world_size) - # 4. 
[Online softmax] second pass: calculate the gradients - # dx_y = (softmax(x_y) - 1) / N - # dx_i = softmax(x_i) / N, i != y - # N is the number of non ignored elements in the batch - # For label smoothing: - # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y - # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N - # = dx_i - (1 - label_smoothing) / N + # Second pass: write gradients into X_ptr for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) grad_dtype = X_block.dtype X_block = X_block.to(tl.float32) if label_smoothing > 0: - # scale X beforehand to avoid overflow scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - # Scale gradients based on reduction mode - # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore - # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here + if reduce_loss: - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + # scale by the true number of non-ignored tokens + X_block = (tl.exp(X_block - m) / d - eps) / n_non_ignore else: X_block = tl.exp(X_block - m) / d - eps + tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) - # We need tl.debug_barrier() to ensure the new result of X_ptr is written tl.debug_barrier() - # 5. Calculate the loss - - # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) - # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # Loss for this row loss = -(ori_X_y - m - tl.log(d)) - - # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps - # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) - # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) - # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) - # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 if label_smoothing > 0: smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) loss = loss * (1 - label_smoothing) + smooth_loss - # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` + # Handle target position vocab_start_idx = rank * n_cols vocab_end_idx = (rank + 1) * n_cols if y >= vocab_start_idx: if y < vocab_end_idx: X_y = tl.load(X_ptr + y - vocab_start_idx) - # Apply the same conditional scaling logic for the target token if reduce_loss: - X_y += -(1 - label_smoothing) / (n_non_ignore) + X_y += -(1 - label_smoothing) / n_non_ignore else: X_y += -(1 - label_smoothing) tl.store(X_ptr + y - vocab_start_idx, X_y) @@ -230,31 +192,16 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, + grad_output_stride, # NEW n_cols, BLOCK_SIZE: tl.constexpr, ): - """ - This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. - The multiplication is performed in-place on the tensor pointed by X_ptr. - - Parameters: - X_ptr: Pointer to the input tensor. - X_stride (int): The stride of the input tensor. - grad_output_ptr: Pointer to the gradient output value. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. 
- """ - - # Get the program ID and convert it to int64 to avoid overflow program_id = tl.program_id(0).to(tl.int64) - # Locate the start index X_ptr += program_id * X_stride - - # Load the gradient output value + grad_output_ptr += program_id * grad_output_stride grad_output = tl.load(grad_output_ptr) - # Perform the element-wise multiplication for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) @@ -269,8 +216,6 @@ def cross_entropy_forward( dist_process_group: Union[dist.ProcessGroup, None], ignore_idx: int, ): - """Forward implementation of Cross Entropy kernel""" - B, SQ, V = _input.shape n_rows = B * SQ @@ -278,13 +223,9 @@ def cross_entropy_forward( BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - # unreduced loss loss_1d = torch.zeros(n_rows, dtype=torch.float32, device=_input.device) - - # tensor to hold this rank's m/d/X_y values m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device) - # ensure _input and target are contiguous in the last dimension if _input.stride(-1) != 1: _input = _input.contiguous() if target.stride(-1) != 1: @@ -296,7 +237,7 @@ def cross_entropy_forward( X_ptr=_input, X_stride=_input.stride(-2), Y_ptr=target, - Y_stride=target.stride(-1), # always 1 + Y_stride=target.stride(-1), m_d_X_y_ptr=m_d_X_y, m_d_X_y_stride=m_d_X_y.stride(-1), rank=rank, @@ -315,6 +256,10 @@ def cross_entropy_forward( else: m_d_X_y_gathered = m_d_X_y + # true number of non-ignored tokens + n_non_ignore = int((target != ignore_idx).sum().item()) + denom = max(n_non_ignore, 1) # avoid div-by-zero in degenerate batches + cross_entropy_kernel[(n_rows,)]( X_ptr=_input, X_stride=_input.stride(-2), @@ -328,14 +273,18 @@ def cross_entropy_forward( world_size=world_size, ignore_idx=ignore_idx, n_cols=V, - n_non_ignore=n_rows, + n_rows=n_rows, # NEW: used only for all-gather offset + n_non_ignore=denom, # used only for gradient scaling in mean mode reduce_loss=reduce_loss, label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) - loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows) + if reduce_loss: + loss = torch.sum(loss_1d) / denom # average over non-ignored tokens + else: + loss = torch.reshape(loss_1d, (B, SQ)) return loss, _input @@ -343,10 +292,7 @@ def cross_entropy_forward( def cross_entropy_backward( _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False ): - """Backward implementation of cross entropy loss kernel""" - # If cross entropy is the last layer, grad_output is 1.0. 
Skip the mul to save time - # Only check torch.equal when not in CUDA graph capturable mode if not is_cg_capturable and torch.equal( grad_output, torch.tensor(1.0, device=grad_output.device) ): @@ -356,10 +302,14 @@ def cross_entropy_backward( n_rows = B * SQ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # ensure per-row scaling; make it 1D contiguous + grad_output_1d = grad_output.reshape(-1).contiguous() + element_mul_kernel[(n_rows,)]( _input, _input.stride(-2), - grad_output, + grad_output_1d, + grad_output_1d.stride(-1), # usually 1 V, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From 5ccad3f0d2eac8b39653372ba74255a1310bf2dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 18:58:56 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper --- transformer_engine/pytorch/triton/cross_entropy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index dee9c98b36..78f60ac2c7 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -96,8 +96,8 @@ def cross_entropy_kernel( world_size, ignore_idx, n_cols, - n_rows, # NEW: total rows (B * SQ), used for all-gather offset - n_non_ignore, # number of non-ignored tokens, used for scaling + n_rows, # NEW: total rows (B * SQ), used for all-gather offset + n_non_ignore, # number of non-ignored tokens, used for scaling reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -273,8 +273,8 @@ def cross_entropy_forward( world_size=world_size, ignore_idx=ignore_idx, n_cols=V, - n_rows=n_rows, # NEW: used only for all-gather offset - n_non_ignore=denom, # used only for gradient scaling in mean mode + n_rows=n_rows, # NEW: used only for all-gather offset + n_non_ignore=denom, # used only for gradient scaling in mean mode reduce_loss=reduce_loss, label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, From 988ff3079814feb67c83431123fe54bfa12e9671 Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 30 Aug 2025 09:50:27 +0200 Subject: [PATCH 03/10] fix comments Signed-off-by: Casper --- .../pytorch/triton/cross_entropy.py | 98 +++++++++++++++---- 1 file changed, 80 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 78f60ac2c7..2f343e2016 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -96,12 +96,32 @@ def cross_entropy_kernel( world_size, ignore_idx, n_cols, - n_rows, # NEW: total rows (B * SQ), used for all-gather offset - n_non_ignore, # number of non-ignored tokens, used for scaling + n_rows, + n_non_ignore, reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): + """ + This kernel computes both cross entropy loss and the gradient of the input. + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + loss_stride (int): The stride of the loss tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride: The stride of m/d/X_y tensor. 
+ rank (int): The rank of this device in the TP group. + world_size (int): The size of world involved in this distributed loss calculation. + ignore_idx (int): Tokens to be ignored for loss and gradient calculation. + n_cols (int): The number of columns in the input tensor. + n_rows (int): The number of rows in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + BLOCK_SIZE (int): The block size for Triton operations. + """ program_id = tl.program_id(0).to(tl.int64) # locate the start index @@ -116,19 +136,17 @@ def cross_entropy_kernel( for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) - # leave loss at 0 (loss buffer is initialized to 0 in Python) return loss_ptr += program_id * loss_stride m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride - # Reduce m/d/X_y across TP ranks + # Need to reduce the m/d/X_y values from other TP ranks m = tl.load(m_d_X_y_ptr) d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) for i in range(1, world_size): - # IMPORTANT: hop by n_rows (NOT n_non_ignore) offset = i * 3 * n_rows * m_d_X_y_stride access_ptr = m_d_X_y_ptr + offset m_new = tl.load(access_ptr) @@ -139,43 +157,62 @@ def cross_entropy_kernel( m = tl.maximum(m, m_new) ori_X_y = tl.maximum(ori_X_y, X_y_new) - # Label smoothing setup + # Label smoothing is a general case of normal cross entropy scaled_x_sum = 0.0 eps = label_smoothing / (n_cols * world_size) - # Second pass: write gradients into X_ptr + # 4. [Online softmax] second pass: calculate the gradients + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # N is the number of non ignored elements in the batch + # For label smoothing: + # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) grad_dtype = X_block.dtype X_block = X_block.to(tl.float32) if label_smoothing > 0: + # scale X beforehand to avoid overflow scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - + # Scale gradients based on reduction mode + # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore + # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here if reduce_loss: - # scale by the true number of non-ignored tokens - X_block = (tl.exp(X_block - m) / d - eps) / n_non_ignore + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) else: X_block = tl.exp(X_block - m) / d - eps - tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) + # We need tl.debug_barrier() to ensure the new result of X_ptr is written tl.debug_barrier() - # Loss for this row + # 5. 
Calculate the loss + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) loss = -(ori_X_y - m - tl.log(d)) + + # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 if label_smoothing > 0: smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) loss = loss * (1 - label_smoothing) + smooth_loss - # Handle target position + # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` vocab_start_idx = rank * n_cols vocab_end_idx = (rank + 1) * n_cols if y >= vocab_start_idx: if y < vocab_end_idx: X_y = tl.load(X_ptr + y - vocab_start_idx) + # Apply the same conditional scaling logic for the target token if reduce_loss: - X_y += -(1 - label_smoothing) / n_non_ignore + X_y += -(1 - label_smoothing) / (n_non_ignore) else: X_y += -(1 - label_smoothing) tl.store(X_ptr + y - vocab_start_idx, X_y) @@ -192,16 +229,33 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, - grad_output_stride, # NEW + grad_output_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + grad_output_stride (int): The stride of the output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + # Get the program ID and convert it to int64 to avoid overflow program_id = tl.program_id(0).to(tl.int64) + # Locate the start index X_ptr += program_id * X_stride + + # Load the gradient output value grad_output_ptr += program_id * grad_output_stride grad_output = tl.load(grad_output_ptr) + # Perform the element-wise multiplication for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) @@ -216,6 +270,7 @@ def cross_entropy_forward( dist_process_group: Union[dist.ProcessGroup, None], ignore_idx: int, ): + """Forward implementation of Cross Entropy kernel""" B, SQ, V = _input.shape n_rows = B * SQ @@ -223,9 +278,13 @@ def cross_entropy_forward( BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # unreduced loss loss_1d = torch.zeros(n_rows, dtype=torch.float32, device=_input.device) + + # tensor to hold this rank's m/d/X_y values m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device) + # ensure _input and target are contiguous in the last dimension if _input.stride(-1) != 1: _input = _input.contiguous() if target.stride(-1) != 1: @@ -237,7 +296,7 @@ def cross_entropy_forward( X_ptr=_input, X_stride=_input.stride(-2), Y_ptr=target, - Y_stride=target.stride(-1), + Y_stride=target.stride(-1), # always 1 m_d_X_y_ptr=m_d_X_y, m_d_X_y_stride=m_d_X_y.stride(-1), rank=rank, @@ -273,8 +332,8 @@ def cross_entropy_forward( world_size=world_size, ignore_idx=ignore_idx, n_cols=V, - n_rows=n_rows, # NEW: used only for all-gather offset - n_non_ignore=denom, # used only for gradient scaling in mean mode + n_rows=n_rows, + n_non_ignore=denom, reduce_loss=reduce_loss, label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, @@ -292,7 +351,10 @@ def cross_entropy_forward( def cross_entropy_backward( _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False ): + """Backward implementation of cross entropy loss kernel""" + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + # Only check torch.equal when not in CUDA graph capturable mode if not is_cg_capturable and torch.equal( grad_output, torch.tensor(1.0, device=grad_output.device) ): From 8bea94659c4e29ce4556bcd5148977d943d3240a Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 30 Aug 2025 10:07:20 +0200 Subject: [PATCH 04/10] fix: few more style issues Signed-off-by: Casper --- transformer_engine/pytorch/triton/cross_entropy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 2f343e2016..41f2b4cecc 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -104,6 +104,7 @@ def cross_entropy_kernel( ): """ This kernel computes both cross entropy loss and the gradient of the input. + Parameters: X_ptr: Pointer to input tensor. X_stride (int): The stride of the input tensor. @@ -236,6 +237,7 @@ def element_mul_kernel( """ This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. The multiplication is performed in-place on the tensor pointed by X_ptr. + Parameters: X_ptr: Pointer to the input tensor. X_stride (int): The stride of the input tensor. 
@@ -271,6 +273,7 @@ def cross_entropy_forward( ignore_idx: int, ): """Forward implementation of Cross Entropy kernel""" + B, SQ, V = _input.shape n_rows = B * SQ @@ -340,10 +343,7 @@ def cross_entropy_forward( num_warps=32, ) - if reduce_loss: - loss = torch.sum(loss_1d) / denom # average over non-ignored tokens - else: - loss = torch.reshape(loss_1d, (B, SQ)) + loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / denom) return loss, _input From 2f365347914e2b0fc8ad06b7c730184e721b9d5e Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 30 Aug 2025 10:59:05 +0200 Subject: [PATCH 05/10] fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper --- transformer_engine/pytorch/triton/cross_entropy.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 41f2b4cecc..8505900af1 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -230,7 +230,6 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, - grad_output_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): @@ -242,7 +241,6 @@ def element_mul_kernel( X_ptr: Pointer to the input tensor. X_stride (int): The stride of the input tensor. grad_output_ptr: Pointer to the gradient output value. - grad_output_stride (int): The stride of the output value. n_cols (int): The number of columns in the input tensor. BLOCK_SIZE (int): The block size for Triton operations. """ @@ -254,7 +252,7 @@ def element_mul_kernel( X_ptr += program_id * X_stride # Load the gradient output value - grad_output_ptr += program_id * grad_output_stride + grad_output_ptr += program_id grad_output = tl.load(grad_output_ptr) # Perform the element-wise multiplication @@ -364,17 +362,13 @@ def cross_entropy_backward( n_rows = B * SQ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - # ensure per-row scaling; make it 1D contiguous - grad_output_1d = grad_output.reshape(-1).contiguous() - element_mul_kernel[(n_rows,)]( _input, _input.stride(-2), - grad_output_1d, - grad_output_1d.stride(-1), # usually 1 + grad_output, V, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) - return _input + return _input \ No newline at end of file From 82691858805f880b3ad8ede69d80203e5bd47362 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Aug 2025 08:59:41 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/triton/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 8505900af1..718638f55f 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -371,4 +371,4 @@ def cross_entropy_backward( num_warps=32, ) - return _input \ No newline at end of file + return _input From ac10da43a4801e3d10f49bdc6627b88f3f7b766a Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 30 Aug 2025 11:53:27 +0200 Subject: [PATCH 07/10] fix: only backward was broken Signed-off-by: Casper --- .../pytorch/triton/cross_entropy.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py 
b/transformer_engine/pytorch/triton/cross_entropy.py index 718638f55f..4ebc613ba0 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -96,7 +96,6 @@ def cross_entropy_kernel( world_size, ignore_idx, n_cols, - n_rows, n_non_ignore, reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, @@ -118,11 +117,11 @@ def cross_entropy_kernel( world_size (int): The size of world involved in this distributed loss calculation. ignore_idx (int): Tokens to be ignored for loss and gradient calculation. n_cols (int): The number of columns in the input tensor. - n_rows (int): The number of rows in the input tensor. n_non_ignore (int): The number of non-ignored elements in the batch. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. BLOCK_SIZE (int): The block size for Triton operations. """ + program_id = tl.program_id(0).to(tl.int64) # locate the start index @@ -148,7 +147,7 @@ def cross_entropy_kernel( ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) for i in range(1, world_size): - offset = i * 3 * n_rows * m_d_X_y_stride + offset = i * 3 * n_non_ignore * m_d_X_y_stride access_ptr = m_d_X_y_ptr + offset m_new = tl.load(access_ptr) d_new = tl.load(access_ptr + m_d_X_y_stride) @@ -191,6 +190,7 @@ def cross_entropy_kernel( tl.debug_barrier() # 5. Calculate the loss + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) loss = -(ori_X_y - m - tl.log(d)) @@ -316,10 +316,6 @@ def cross_entropy_forward( else: m_d_X_y_gathered = m_d_X_y - # true number of non-ignored tokens - n_non_ignore = int((target != ignore_idx).sum().item()) - denom = max(n_non_ignore, 1) # avoid div-by-zero in degenerate batches - cross_entropy_kernel[(n_rows,)]( X_ptr=_input, X_stride=_input.stride(-2), @@ -333,15 +329,14 @@ def cross_entropy_forward( world_size=world_size, ignore_idx=ignore_idx, n_cols=V, - n_rows=n_rows, - n_non_ignore=denom, + n_non_ignore=n_rows, reduce_loss=reduce_loss, label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) - loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / denom) + loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows) return loss, _input @@ -371,4 +366,4 @@ def cross_entropy_backward( num_warps=32, ) - return _input + return _input \ No newline at end of file From c3c7b73655662953eea66af44001c807c07a834a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Aug 2025 09:53:56 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/triton/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 4ebc613ba0..a40377a292 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -366,4 +366,4 @@ def cross_entropy_backward( num_warps=32, ) - return _input \ No newline at end of file + return _input From 1b2b4a8e48978cc9fc0bd89c700628aaf8e5921b Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 4 Sep 2025 00:54:59 +0000 Subject: [PATCH 09/10] Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon --- 
tests/pytorch/test_parallel_cross_entropy.py | 58 ++++++++++++------- .../pytorch/triton/cross_entropy.py | 4 +- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 77bea2b360..1e17395452 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -6,6 +6,7 @@ import torch from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from utils import dtype_tols class TestParallelCrossEntropy: @@ -18,19 +19,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float): label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" ) - def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): - + def generate_input( + self, + dtype: torch.dtype, + swap_dim: bool, + ignore_idx: bool, + device: torch.device = "cuda", + ): SQ = random.choice([64, 128]) batch = random.choice([1, 2]) vocab = random.choice([64000, 128000]) ignore = random.sample(range(0, SQ - 1), 5) + # Generate random data if swap_dim: - self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda() + self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device) else: - self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda() + self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device) if ignore_idx: for i in ignore: @@ -40,9 +47,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): else: self.tar_test[0][i] = -100 + # Make copy of data for reference implementation self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab)) self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,)) + # Enable autograd + self.input_test.requires_grad_() + self.input_ref.requires_grad_() + def one_iteration_test( self, dtype: torch.dtype, @@ -52,18 +64,20 @@ def one_iteration_test( ignore_idx: bool = False, ): + # Random data self.generate_input(dtype, swap_dim, ignore_idx) - self.input_test.requires_grad_(True) - self.input_ref.requires_grad_(True) - + # Forward pass test_loss = self.test_loss_func( self.input_test, self.tar_test, label_smoothing, reduce_loss, None ) - ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) - # Handle backward pass based on the test scenario + # Compute square to avoid trivial backward pass + test_loss = torch.square(test_loss) + ref_loss = torch.square(ref_loss) + + # Backward pass if reduce_loss: test_loss.backward() ref_loss.backward() @@ -71,16 +85,18 @@ def one_iteration_test( test_loss.sum().backward() ref_loss.sum().backward() - test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss - - if ignore_idx: - print(test_loss, ref_loss) - - # Compare gradients when backward pass was called - torch.testing.assert_close( - torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad - ) - + # Check that loss and grad input match + tols = dtype_tols(dtype) + test_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = ref_loss.reshape(test_loss.size()) + test_grad_input = 
self.input_test.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = ref_grad_input.reshape(test_grad_input.size()) + torch.testing.assert_close(test_loss, ref_loss, **tols) + torch.testing.assert_close(test_grad_input, ref_grad_input, **tols) + + # Reset data self.input_test = None self.input_ref = None self.tar_test = None diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index a40377a292..7cfff1da9d 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -230,6 +230,7 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, + grad_output_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): @@ -252,7 +253,7 @@ def element_mul_kernel( X_ptr += program_id * X_stride # Load the gradient output value - grad_output_ptr += program_id + grad_output_ptr += program_id * grad_output_stride grad_output = tl.load(grad_output_ptr) # Perform the element-wise multiplication @@ -361,6 +362,7 @@ def cross_entropy_backward( _input, _input.stride(-2), grad_output, + 1 if grad_output.numel() > 1 else 0, V, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From 64e52cccf54dc4c417f8de23e6022d1544b9d5b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Sep 2025 00:55:39 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_parallel_cross_entropy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 1e17395452..fa56852ffc 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -8,6 +8,7 @@ from utils import dtype_tols + class TestParallelCrossEntropy: def generate_iters(self, iters: int):
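
Illustrative sketch (not part of the patch series): the test rewritten in patch 09 squares an unreduced loss so that grad_output becomes a non-trivial per-token tensor, which is exactly the case the new grad_output_stride argument of element_mul_kernel is meant to handle. The single-GPU sketch below mirrors that idea against torch.nn.functional.cross_entropy. It assumes the import path and positional call signature used in tests/pytorch/test_parallel_cross_entropy.py, i.e. parallel_cross_entropy(input, target, label_smoothing, reduce_loss, dist_process_group), plus no ignored tokens and loose tolerances; it is a sanity check under those assumptions, not a definitive reference for the kernel's numerics.

# Sketch only: single GPU, no tensor parallelism (dist_process_group=None), no ignored tokens.
import torch
import torch.nn.functional as F
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

torch.manual_seed(0)
B, SQ, V = 2, 64, 1024
logits = torch.randn(B, SQ, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B, SQ), device="cuda")

# Fused Triton path, unreduced: per-token loss of shape (B, SQ).
loss = parallel_cross_entropy(logits, target, 0.0, False, None)
# Squaring makes grad_output a per-token tensor rather than all-ones,
# exercising the per-row grad_output indexing in the backward kernel.
torch.square(loss).sum().backward()

# Reference path on a flattened, detached copy of the logits.
ref_logits = logits.detach().clone().requires_grad_()
ref_loss = F.cross_entropy(ref_logits.view(-1, V), target.view(-1), reduction="none")
torch.square(ref_loss).sum().backward()

torch.testing.assert_close(loss.view(-1), ref_loss, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(logits.grad.view(-1, V), ref_logits.grad, rtol=1e-2, atol=1e-2)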