From 66edd7454d8b9bf491035b6029b260c349b412ed Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Mon, 16 Jun 2025 15:22:23 -0400
Subject: [PATCH] [CUTLASS] Fix CUTLASS kernel build on Hopper

The CUTLASS kernel build on Hopper GPUs has been broken since #18033.
This PR fixes the issue by moving the KernelTraits specializations into
fp16_group_gemm_runner_sm90.cuh and by updating fp8_group_gemm_sm90.cu
to call cutlass_group_gemm_sm90 with the correct includes.
---
 .../cutlass/fp16_group_gemm_runner_sm90.cuh        | 14 ++++++++++++++
 .../contrib/cutlass/fp16_group_gemm_sm90.cu        | 14 --------------
 src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu | 12 ++++++------
 3 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
index 38e1beb2b8f4..246063ca0341 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
@@ -57,6 +57,20 @@ inline size_t aligned(size_t value, size_t alignment = 16) {
 template <typename T>
 struct KernelTraits;
 
+template <>
+struct KernelTraits<cutlass::half_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
+};
+
+template <>
+struct KernelTraits<cutlass::bfloat16_t> {
+  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
+  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
+};
+
 template <
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu
@@ ... @@
-template <>
-struct KernelTraits<cutlass::half_t> {
-  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
-  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
-  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
-};
-
-template <>
-struct KernelTraits<cutlass::bfloat16_t> {
-  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
-  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
-  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
-};
-
 void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
                                  NDArray out) {
   tvm_cutlass_group_gemm_impl<90>(x, weight, indptr, workspace, out);
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
index 686a6ebcffeb..0eaa6a1efb77 100644
--- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
@@ -19,9 +19,8 @@
 #include
 #include
-#include
-#include
 #include
+#include
 
 #include "fp16_group_gemm_runner_sm90.cuh"
@@ -60,10 +59,11 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
   int n = weight->shape[1];
   int k = x->shape[1];
   const float* beta = nullptr;
-  cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
-                     static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
-                     workspace->shape[0], n, k, num_groups, static_cast<float*>(alpha->data), beta,
-                     static_cast<ElementC*>(out->data), stream);
+  cutlass_group_gemm_sm90(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
+                          static_cast<int64_t*>(indptr->data),
+                          static_cast<uint8_t*>(workspace->data), workspace->shape[0], n, k,
+                          num_groups, static_cast<float*>(alpha->data), beta,
+                          static_cast<ElementC*>(out->data), stream);
 }
 
 TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
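
Reviewer note (not part of the commit): moving the KernelTraits specializations into
fp16_group_gemm_runner_sm90.cuh makes them visible to every translation unit that
instantiates the shared runner template. The minimal, self-contained C++ sketch below
illustrates that visibility rule; every name in it (run_group_gemm, tile_m, and the
float/double stand-ins for cutlass::half_t / cutlass::bfloat16_t) is hypothetical and
is not TVM code.

    // sketch.cpp -- illustrative only; names are hypothetical, not TVM's.
    #include <iostream>

    // Primary template: declared but never defined, as in the runner header.
    template <typename T>
    struct KernelTraits;

    // Explicit specializations, standing in for the ones this patch moves
    // into the shared header (float/double play the role of
    // cutlass::half_t / cutlass::bfloat16_t).
    template <>
    struct KernelTraits<float> {
      static constexpr int tile_m = 128;  // stand-in for TileShape
    };

    template <>
    struct KernelTraits<double> {
      static constexpr int tile_m = 128;
    };

    // A consumer template, playing the role of the group GEMM runner. It can
    // only be instantiated for types whose KernelTraits specialization is
    // visible in this translation unit -- which is why the specializations
    // belong in the shared header rather than in a single .cu file.
    template <typename T>
    void run_group_gemm() {
      std::cout << "tile_m = " << KernelTraits<T>::tile_m << "\n";
    }

    int main() {
      run_group_gemm<float>();   // OK: KernelTraits<float> is visible here
      run_group_gemm<double>();  // OK: KernelTraits<double> is visible here
      // run_group_gemm<int>();  // would not compile: KernelTraits<int> is incomplete
      return 0;
    }

Compiling the sketch with g++ -std=c++17 sketch.cpp succeeds and prints tile_m = 128
twice; if the two specializations lived only in a separate .cu file, the instantiations
in main would fail with an incomplete-type error, which matches the kind of breakage a
specialization that is invisible at the point of instantiation causes.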