
[deepseek][blackwell] add Cutlass cute dsl blackwell dense based looping group gemm #1274

Open · wants to merge 3 commits into main
2 changes: 1 addition & 1 deletion torchtitan/experiments/deepseek_v3/generate.py
@@ -224,7 +224,7 @@ def generate(
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 200,
    n_tokens: int = 50,
):
    rank = dist.get_rank()
    device = dist_config.device
419 changes: 419 additions & 0 deletions torchtitan/experiments/deepseek_v3/group_gemms.py
@@ -47,6 +47,24 @@
except ImportError:
    TRITON_CONTIGUOUS_GROUP_GEMM_AVAILABLE = False

# Cutlass Cute DSL
try:
    import cuda.bindings.driver as cuda
    import cutlass
    import cutlass.cute as cute

    # import cutlass.torch as cutlass_torch
    # import cutlass.utils as utils
    from cutlass.cute.runtime import from_dlpack

    from torchtitan.experiments.kernels.blackwell.cute_dense_gemm import DenseGemmKernel

    CUTLASS_AVAILABLE = True
except ImportError as e:
    CUTLASS_AVAILABLE = False
    print(f"Cutlass imports not available: {e}")
    print("Please run `pip install nvidia-cutlass-dsl`")


# Strategy base class for GroupGEMM implementations
class GroupGEMMStrategy:
@@ -97,9 +115,410 @@ def is_available() -> bool:
    "TorchBF16GroupGEMM",
    "TorchAOBF16GroupGEMM",
    "TritonCGBF16GroupGEMM",
    "CuteDenseLoopingGroupGEMM",
]


# requires pip install nvidia-cutlass-dsl
class CuteDenseLoopingGroupGEMM(GroupGEMMStrategy):
    """
    Grouped GEMM implementation that loops over experts using the Blackwell Dense GEMM kernel.
    High-level overview:
    - Kernel caching: compiled kernels are cached and reused across calls
    - Expert token tensor reuse: in the MoE forward pass, expert_tokens are converted to CUTE
      format once and reused for both the gate and up projections
    - Fallback: falls back to the PyTorch implementation if the CUTE kernels fail
    """

    def __init__(self, custom_activation):
        super().__init__(custom_activation)

        # Kernel configuration
        self.alignment = 16
        self.dtype = torch.bfloat16
        self.cutlass_dtype = cutlass.BFloat16

        # Initialize Cute Dense GEMM kernel
        try:
            self.gemm_kernel = DenseGemmKernel(
                acc_dtype=cutlass.Float32,
                use_2cta_instrs=False,
                mma_tiler_mn=(128, 128),
                cluster_shape_mn=(1, 1),
                use_tma_store=False,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize GEMM kernel: {e}") from e

        # Setup CUDA stream
        torch_stream = torch.cuda.Stream()
        self.stream = cuda.CUstream(torch_stream.cuda_stream)

        # Cache for compiled kernels
        self._compiled_kernels = {}

        # debug monitoring
        self.debug_mode = True

    def arrange_expert_weights(self, all_weights, submod_name, module):
        """Stack the per-expert weights into a single tensor"""
        return torch.stack(all_weights)

    def _create_cute_tensor(self, tensor: torch.Tensor) -> cute.Tensor:
        """
        Convert a PyTorch tensor to a CUTE tensor with proper formatting.
        Args:
            tensor: PyTorch tensor to convert
        Returns:
            CUTE tensor ready for kernel execution
        """
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()

        # Convert to MNKL format
        tensor_mnkl = tensor.unsqueeze(-1).contiguous().detach()

        # Create CUTE tensor
        cute_tensor = from_dlpack(tensor_mnkl, assumed_align=self.alignment)
        cute_tensor.element_type = self.cutlass_dtype
        cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=1)

        return cute_tensor

    def _get_or_compile_kernel(self, a_cute, b_cute, c_cute, operation_name: str):
        """
        Get a compiled kernel from cache or compile a new one.
        Args:
            a_cute: Input tensor A in CUTE format
            b_cute: Input tensor B in CUTE format
            c_cute: Output tensor C in CUTE format
            operation_name: Name of the operation for caching
        Returns:
            Compiled CUTE kernel
        """
        cache_key = f"{operation_name}_{a_cute.shape}_{b_cute.shape}_{c_cute.shape}"

        if cache_key not in self._compiled_kernels:
            try:
                self._compiled_kernels[cache_key] = cute.compile(
                    self.gemm_kernel, a_cute, b_cute, c_cute, self.stream
                )
                if self.debug_mode:
                    print(f"✓ Compiled kernel for {operation_name}")
            except Exception as e:
                raise RuntimeError(
                    f"Failed to compile {operation_name} kernel: {e}"
                ) from e

        return self._compiled_kernels[cache_key]

    def _execute_gemm_operation(
        self, input_tensor: torch.Tensor, weight: torch.Tensor, operation_name: str
    ) -> torch.Tensor:
        """
        Execute a single GEMM operation using cute dense kernel.
        Args:
            input_tensor: Input tensor [M, K]
            weight: Weight tensor [N, K]
            operation_name: Name of the operation for debugging
        Returns:
            Output tensor [M, N]
        """
        batch_size, input_dim = input_tensor.shape
        output_dim = weight.shape[0]

        # Create output tensor
        output = torch.zeros(
            (batch_size, output_dim),
            device=input_tensor.device,
            dtype=self.dtype,
            requires_grad=False,
        )

        # Convert tensors to cute format
        try:
            a_cute = self._create_cute_tensor(input_tensor)
            b_cute = self._create_cute_tensor(weight)
            c_cute = self._create_cute_tensor(output)
        except Exception as e:
            raise RuntimeError(
                f"Failed to create CUTE tensors for {operation_name}: {e}"
            ) from e

        # Get or compile kernel
        compiled_kernel = self._get_or_compile_kernel(
            a_cute, b_cute, c_cute, operation_name
        )

        # Execute kernel
        try:
            compiled_kernel(a_cute, b_cute, c_cute, self.stream)
            if self.debug_mode:
                print(f"✓ Executed {operation_name} kernel successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to execute {operation_name} kernel: {e}") from e

        return output.squeeze(-1) if output.dim() > 2 else output

    def _execute_gemm_with_cute_input(
        self,
        input_cute: cute.Tensor,
        weight: torch.Tensor,
        operation_name: str,
        output_shape: tuple,
    ) -> torch.Tensor:
        """
        Execute a GEMM operation with pre-converted cute input tensor.
        Args:
            input_cute: Input cute tensor (already in cute format)
            weight: Weight tensor [N, K]
            operation_name: Name of the operation for debugging
            output_shape: Shape of output tensor (batch_size, output_dim)
        Returns:
            Output tensor [M, N]
        """
        batch_size, output_dim = output_shape

        # Create output tensor
        output = torch.zeros(
            (batch_size, output_dim),
            device=weight.device,
            dtype=self.dtype,
            requires_grad=False,
        )

        # Convert weight and output tensors to CUTE format
        try:
            b_cute = self._create_cute_tensor(weight)
            c_cute = self._create_cute_tensor(output)
        except Exception as e:
            raise RuntimeError(
                f"Failed to create CUTE tensors for {operation_name}: {e}"
            ) from e

        # Get or compile kernel
        compiled_kernel = self._get_or_compile_kernel(
            input_cute, b_cute, c_cute, operation_name
        )

        # Execute kernel
        try:
            compiled_kernel(input_cute, b_cute, c_cute, self.stream)
            if self.debug_mode:
                print(f"✓ Executed {operation_name} kernel successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to execute {operation_name} kernel: {e}") from e

        return output.squeeze(-1) if output.dim() > 2 else output

    def _process_expert(
        self,
        expert_tokens: torch.Tensor,
        expert_idx: int,
        w_gate: torch.Tensor,
        w_up: torch.Tensor,
        w_down: torch.Tensor,
    ) -> torch.Tensor:
        """
        Process tokens through a single expert using cute dense kernels.
        Args:
            expert_tokens: Tokens for this expert [num_tokens, hidden_size]
            expert_idx: Index of the expert
            w_gate: Gate projection weights [intermediate_size, hidden_size]
            w_up: Up projection weights [intermediate_size, hidden_size]
            w_down: Down projection weights [hidden_size, intermediate_size]
        Returns:
            Expert output [num_tokens, hidden_size]
        """
        num_tokens = expert_tokens.shape[0]
        intermediate_size = w_gate.shape[0]
        hidden_size = w_down.shape[0]

        # Convert expert_tokens to CUTE format once for reuse
        # OPTIMIZATION: Gate and up projections share the same input tensor,
        # so we convert to CUTE format once and reuse it to avoid redundant overhead
        try:
            expert_tokens_cute = self._create_cute_tensor(expert_tokens)
        except Exception as e:
            raise RuntimeError(
                f"Failed to create CUTE tensor for expert {expert_idx} input: {e}"
            ) from e

        # Gate projection - reuse the CUTE input tensor
        gate_out = self._execute_gemm_with_cute_input(
            expert_tokens_cute,
            w_gate,
            f"gate_expert_{expert_idx}",
            (num_tokens, intermediate_size),
        )

        # Up projection - reuse the same CUTE input tensor
        up_out = self._execute_gemm_with_cute_input(
            expert_tokens_cute,
            w_up,
            f"up_expert_{expert_idx}",
            (num_tokens, intermediate_size),
        )

        # Apply activation and combine
        hidden = self.activation_function(gate_out) * up_out

        # Down projection - create a new CUTE tensor for the hidden state
        expert_output = self._execute_gemm_operation(
            hidden, w_down, f"down_expert_{expert_idx}"
        )

        return expert_output

    def execute(self, contig_tokens, m_sizes, m_offsets, module):
        """
        Execute the complete grouped GEMM operation via looping.
        Args:
            contig_tokens: Input tokens arranged contiguously by expert
            m_sizes: Sizes of each group
            m_offsets: Offsets of each group
            module: MoE module containing weights and parameters
        Returns:
            Processed tokens
        """
        try:
            # Get weights
            device = contig_tokens.device
            w_gate = module.get_parameter("gate_proj_weight")
            w_up = module.get_parameter("up_proj_weight")
            w_down = module.get_parameter("down_proj_weight")

            # Validate inputs
            if len(m_sizes) != w_gate.shape[0]:
                raise ValueError(
                    f"Number of experts mismatch: {len(m_sizes)} vs {w_gate.shape[0]}"
                )

            # Prepare output tensor
            hidden_size = w_gate.shape[2] if len(w_gate.shape) > 2 else w_gate.shape[1]
            output = torch.zeros(
                contig_tokens.shape[0],
                hidden_size,
                dtype=contig_tokens.dtype,
                device=device,
            )

            # Process each expert
            offset = 0
            active_experts = 0

            for expert_idx, size in enumerate(m_sizes):
                if size > 0:
                    # Get tokens and weights for this expert
                    expert_tokens = contig_tokens[offset : offset + size]
                    expert_gate_weight = w_gate[expert_idx]
                    expert_up_weight = w_up[expert_idx]
                    expert_down_weight = w_down[expert_idx]

                    # Process through expert
                    expert_output = self._process_expert(
                        expert_tokens,
                        expert_idx,
                        expert_gate_weight,
                        expert_up_weight,
                        expert_down_weight,
                    )

                    # Store results
                    output[offset : offset + size] = expert_output
                    active_experts += 1

                offset += size

            if self.debug_mode:
                print(
                    f"Processed {active_experts} active experts out of {len(m_sizes)} total"
                )

            return output

        except Exception as e:
            # Fallback to PyTorch implementation on error
            if self.debug_mode:
                print(f"CUTE kernel failed, falling back to PyTorch: {e}")
            return self._fallback_pytorch(contig_tokens, m_sizes, module)

    def _fallback_pytorch(self, contig_tokens, m_sizes, module):
        """
        Fallback implementation using standard PyTorch operations.
        Args:
            contig_tokens: Input tokens
            m_sizes: Group sizes
            module: MoE module
        Returns:
            Processed tokens using PyTorch mm
        """
        print("\nWARNING: Cute GEMM issue -- Falling back to PyTorch implementation\n")
        device = contig_tokens.device
        w_gate = module.get_parameter("gate_proj_weight")
        w_up = module.get_parameter("up_proj_weight")
        w_down = module.get_parameter("down_proj_weight")

        hidden_size = w_gate.shape[2] if len(w_gate.shape) > 2 else w_gate.shape[1]
        output = torch.zeros(
            contig_tokens.shape[0],
            hidden_size,
            dtype=contig_tokens.dtype,
            device=device,
        )

        offset = 0
        for expert_idx, size in enumerate(m_sizes):
            if size > 0:
                expert_tokens = contig_tokens[offset : offset + size]

                # Standard PyTorch forward pass
                gate_out = torch.mm(expert_tokens, w_gate[expert_idx].t())
                up_out = torch.mm(expert_tokens, w_up[expert_idx].t())
                hidden = self.activation_function(gate_out) * up_out
                expert_output = torch.mm(hidden, w_down[expert_idx].t())

                output[offset : offset + size] = expert_output

            offset += size

        return output

    def clear_cache(self):
        """Clear the compiled kernel cache to free memory."""
        self._compiled_kernels.clear()
        if self.debug_mode:
            print("Cleared compiled kernel cache")

    def set_debug_mode(self, enabled: bool = False):
        """Enable or disable debug mode."""
        self.debug_mode = enabled

    @staticmethod
    def is_available() -> bool:
        """Check if this strategy is available on the current system."""
        try:
            return CUTLASS_AVAILABLE and torch.cuda.is_available()
        except Exception:
            return False


class TritonCGBF16GroupGEMM(GroupGEMMStrategy):
    """Implementation of Triton Contiguous group Gemm"""

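
For orientation, the new strategy is driven entirely through its `execute(contig_tokens, m_sizes, m_offsets, module)` entry point, with the expert weights fetched from the module via `get_parameter`. Below is a minimal smoke-test sketch, not part of this PR: it assumes `nvidia-cutlass-dsl` is installed and a Blackwell-class GPU is present; the `_FakeMoEModule` stand-in, its sizes, and the use of `F.silu` as the activation are illustrative assumptions, and only the parameter names and shapes mirror the diff above.

import torch
import torch.nn.functional as F

from group_gemms import CuteDenseLoopingGroupGEMM


class _FakeMoEModule(torch.nn.Module):
    # Minimal stand-in exposing stacked expert weights under the names the strategy queries.
    def __init__(self, num_experts=4, hidden=128, intermediate=256):
        super().__init__()
        kwargs = dict(dtype=torch.bfloat16, device="cuda")
        self.gate_proj_weight = torch.nn.Parameter(torch.randn(num_experts, intermediate, hidden, **kwargs))
        self.up_proj_weight = torch.nn.Parameter(torch.randn(num_experts, intermediate, hidden, **kwargs))
        self.down_proj_weight = torch.nn.Parameter(torch.randn(num_experts, hidden, intermediate, **kwargs))


if CuteDenseLoopingGroupGEMM.is_available():
    strategy = CuteDenseLoopingGroupGEMM(F.silu)  # activation is an assumption; model.py passes MLP.act_fn
    module = _FakeMoEModule()
    m_sizes = [32, 0, 16, 48]  # tokens routed to each of the 4 experts
    m_offsets = torch.cumsum(torch.tensor(m_sizes), dim=0)
    tokens = torch.randn(sum(m_sizes), 128, dtype=torch.bfloat16, device="cuda")
    out = strategy.execute(tokens, m_sizes, m_offsets, module)
    print(out.shape)  # expected: torch.Size([96, 128])
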
9 changes: 8 additions & 1 deletion torchtitan/experiments/deepseek_v3/model.py
@@ -45,6 +45,7 @@
from attn_mask_utils import _prepare_4d_causal_attention_mask

from group_gemms import (
    CuteDenseLoopingGroupGEMM,
    DSGroupGEMM,
    TorchAOBF16GroupGEMM,
    TorchBF16GroupGEMM,
@@ -474,7 +475,8 @@ class MoE(nn.Module):
    # Group GEMM strategies
    group_gemm_strategies = None
    # which group gemm to use?
    group_mm = "torch"  # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", "torchao", "tritoncg"]
    group_mm = "cute"  # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", "torchao", "tritoncg", "cute"]
    print(f"Using group gemm strategy: {group_mm}")

    def __init__(self, config):
        super().__init__()
@@ -550,6 +552,11 @@ def _initialize_group_gemm_strategies(cls):
                if TritonCGBF16GroupGEMM.is_available()
                else None
            ),
            "cute": (
                CuteDenseLoopingGroupGEMM(MLP.act_fn)
                if CuteDenseLoopingGroupGEMM.is_available()
                else None
            ),
        }

    def combine_experts(self, submod_name: str):
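
The registry above maps each strategy name to an instance, or to None when its is_available() check fails at initialization time. A hypothetical helper (the name resolve_group_gemm and its signature are illustrative, not part of this diff) shows how a caller could fall back to the plain torch entry when the cute path is unavailable:

def resolve_group_gemm(strategies: dict, preferred: str = "cute", fallback: str = "torch"):
    # Registry entries are None when a backend's is_available() check failed.
    strategy = strategies.get(preferred)
    if strategy is None:
        print(f"Group GEMM strategy '{preferred}' unavailable; using '{fallback}'")
        strategy = strategies[fallback]  # assumes the plain torch strategy is always registered
    return strategy
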
1 change: 1 addition & 0 deletions torchtitan/experiments/deepseek_v3/requirements.txt
@@ -3,3 +3,4 @@ accelerate
torchdata >= 0.8.0
datasets >= 2.21.0
tomli >= 1.1.0 ; python_version < "3.11"
nvidia-cutlass-dsl
1,934 changes: 1,934 additions & 0 deletions torchtitan/experiments/kernels/blackwell/cute_dense_gemm.py

Large diffs are not rendered by default.