Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ inline size_t aligned(size_t value, size_t alignment = 16) {
// Per-element-type kernel configuration, keyed on the element type T.
// Only declared at this point; the explicit specializations that follow
// provide the KernelSchedule, TileShape and ClusterShape member aliases.
template <typename T>
struct KernelTraits;

// Kernel configuration selected when the element type is cutlass::half_t.
template <>
struct KernelTraits<cutlass::half_t> {
  // Tile size handled at the threadblock level.
  using TileShape = Shape<_128, _256, _64>;
  // Arrangement of the threadblocks that form one cluster.
  using ClusterShape = Shape<_2, _2, _1>;
  // Kernel schedule tag taken from cutlass::gemm.
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
};

// Kernel configuration selected when the element type is cutlass::bfloat16_t.
template <>
struct KernelTraits<cutlass::bfloat16_t> {
  // Tile size handled at the threadblock level.
  using TileShape = Shape<_128, _256, _64>;
  // Arrangement of the threadblocks that form one cluster.
  using ClusterShape = Shape<_2, _2, _1>;
  // Kernel schedule tag taken from cutlass::gemm.
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
};

template <typename ElementA, typename ElementB, typename ElementC,
typename LayoutA = cutlass::layout::RowMajor,
typename LayoutB = cutlass::layout::ColumnMajor,
Expand Down
14 changes: 0 additions & 14 deletions src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,6 @@ struct CutlassGroupGemm<90, ElementA, ElementB, ElementC> {
}
};

// Kernel configuration selected when the element type is cutlass::half_t.
template <>
struct KernelTraits<cutlass::half_t> {
  // Tile size handled at the threadblock level.
  using TileShape = Shape<_128, _256, _64>;
  // Arrangement of the threadblocks that form one cluster.
  using ClusterShape = Shape<_2, _2, _1>;
  // Kernel schedule tag taken from cutlass::gemm.
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
};

// Kernel configuration selected when the element type is cutlass::bfloat16_t.
template <>
struct KernelTraits<cutlass::bfloat16_t> {
  // Tile size handled at the threadblock level.
  using TileShape = Shape<_128, _256, _64>;
  // Arrangement of the threadblocks that form one cluster.
  using ClusterShape = Shape<_2, _2, _1>;
  // Kernel schedule tag taken from cutlass::gemm.
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
};

void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
NDArray out) {
tvm_cutlass_group_gemm_impl<90>(x, weight, indptr, workspace, out);
Expand Down
12 changes: 6 additions & 6 deletions src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@

#include <cuda_fp16.h>
#include <float.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/function.h>
#include <tvm/runtime/ndarray.h>

#include "fp16_group_gemm_runner_sm90.cuh"

Expand Down Expand Up @@ -60,10 +59,11 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
int n = weight->shape[1];
int k = x->shape[1];
const float* beta = nullptr;
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<float*>(alpha->data), beta,
static_cast<ElementC*>(out->data), stream);
cutlass_group_gemm_sm90(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data),
static_cast<uint8_t*>(workspace->data), workspace->shape[0], n, k,
num_groups, static_cast<float*>(alpha->data), beta,
static_cast<ElementC*>(out->data), stream);
}

TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
Expand Down
Loading