Skip to content

Commit 06e684c

Browse files
authored
Adding cuda kernel (optimized for sm80) for block-wise 4b quantized float 16 GEMM. (#18619)
### Description Adding CUDA kernel for block-wise 4b quantized float 16 GEMM, this is specially optimized for Nvidia Ampere GPUs. ### Motivation and Context Trying to improve quantized LLM inference performance on Nvidia Ampere GPUs ### Note: This is implemented by extending CUTLASS, so it has a hard dependency on CUTLASS. However, in current build system, loading of CUTLASS dependency is guarded with: (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) If both of these options are turned off, then compilation will fail. Why CUTLASS dependency is guarded at all? It's a header file only library that does not introduce any binary if not instantiated. What's the downside of removing all the guards and just include CUTLASS unconditionally?
1 parent bdf678d commit 06e684c

25 files changed

+6257
-513
lines changed

.lintrunner.toml

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ exclude_patterns = [
132132
'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
133133
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
134134
'onnxruntime/core/mlas/**', # Contains assembly code
135+
'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting
135136
'winml/lib/Api.Image/shaders/**', # Contains data chunks
136137
]
137138
command = [

cmake/CMakeLists.txt

+4-1
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,9 @@ if (onnxruntime_USE_CUDA)
727727
set(onnxruntime_USE_FLASH_ATTENTION OFF)
728728
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
729729
endif()
730+
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
731+
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
732+
endif()
730733
else()
731734
set(onnxruntime_USE_FLASH_ATTENTION OFF)
732735
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
@@ -747,8 +750,8 @@ if (onnxruntime_USE_CUDA)
747750
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
748751
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
749752
endif()
750-
751753
endif()
754+
752755
if (onnxruntime_USE_VITISAI)
753756
list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
754757
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1)

cmake/onnxruntime_providers_cuda.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@
201201
endif()
202202

203203
include(cutlass)
204-
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
204+
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
205205

206206
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
207207
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})

cmake/onnxruntime_unittests.cmake

+1
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
774774
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
775775
config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut)
776776
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
777+
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
777778
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
778779
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut)
779780
endif()

onnxruntime/core/mickey/README.md

+4
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that
44
are often shared by various AI operators. The intention is to make this
55
header files only, with no binary impact unless it is instantiated
66
where it is needed.
7+
8+
Currently, CUDA code is scattered across multiple locations in the repo.
9+
Hopefully this can be the starting point of consolidating all CUDA
10+
code.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
/**
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License.
 *
 * Module Name:
 *    blk_q4/f16_gemm_sm80.h
 *
 * Abstract:
 *   Entry point for Q4F16 GEMM kernel for SM80 devices.
 */

#pragma once

#include <type_traits>  // std::conditional (layout / tile-shape selection)
#include <variant>      // std::monostate (placeholder offset type when kHasQuantOffset is false)

#include "cutlass/cutlass.h"
#include "cutlass_ext/q4gemm/device/quantb_gemm.h"

namespace onnxruntime {
namespace cuda {

//
// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type.
//
// Template parameters:
//   ElementDequant_  - data type of dequantized elements for gemm, fp16 or bf16
//   QuantBlocking_   - weights block per scale, cutlass::MatrixShape<x,y>
//   SmallM           - true if M <= 16 (selects small-M threadblock/warp tiles)
//   kHasQuantOffset  - true if the quantized weights carry zero-point offsets
//
template <
    typename ElementDequant_,
    typename QuantBlocking_,
    bool SmallM,
    bool kHasQuantOffset>
struct BlkQ4F16GemmImpl {
  //
  // Type definitions
  //

  using ElementDequant = ElementDequant_;
  using QuantBlocking = QuantBlocking_;

  static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kernel only supports 16b operands!");

  // Data types that are fixed for this kernel
  using ElementAccumulator = float;
  using ElementComputeEpilogue = ElementAccumulator;
  using ElementInputA = ElementDequant;
  using ElementOutput = ElementDequant;

  using ElementW = uint8_t;  // <- Weight is int4, uint8 for two of them

  // We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators
  // for async shared memory loading and minimize bank conflict
  using ElementWPack = ElementDequant;

  using ElementQScale = ElementDequant;  // <- data type of quantization scale
  using ElementQOffset = uint8_t;        // <- data type of quantization offset (zero point)

  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputWPack = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Layout of quantization scale and offset, oriented to be loaded using less instructions
  // in a warp tile
  using LayoutInputQScale =
      typename std::conditional<QuantBlocking::kRow == 1,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;  // <- layout of quantization scale

  // Threadblock tile: a slim 16x64x64 tile when M <= 16, a large 128x256x64 tile otherwise
  using ShapeMMAThreadBlock =
      typename std::conditional<SmallM,
                                cutlass::gemm::GemmShape<16, 64, 64>,
                                cutlass::gemm::GemmShape<128, 256, 64>>::type;

  // Warp-tile N must be at least one quantization block wide so a warp never
  // straddles scale/offset blocks along N.
  static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32;
  using ShapeMMAWarp =
      typename std::conditional<SmallM,
                                cutlass::gemm::GemmShape<16, MinN, 64>,
                                cutlass::gemm::GemmShape<64, 64, 64>>::type;

  // Tensor core instruction shape (SM80 16b mma)
  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

  // This code section describes how threadblocks are scheduled on GPU:
  // the default identity swizzle (threadblock index maps directly to output tile).
  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

  // This code section describes the epilogue part of the kernel
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,  // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
                                                         // memory access (8 for 16b output). This
                                                         // becomes the vector width of math
                                                         // instructions in the epilogue too.
      ElementAccumulator,       // <- data type of accumulator
      ElementComputeEpilogue>;  // <- data type for alpha/beta in linear combination function

  // Number of pipeline stages for the async global->shared memory loads
  static constexpr int NumStages = 3;

  // When kHasQuantOffset is false, std::monostate is substituted as the offset
  // element type, so the no-offset specialization of QuantBGemm is selected.
  using Gemm = cutlass::gemm::device::QuantBGemm<
      ElementInputA,
      LayoutInputA,
      ElementWPack,
      LayoutInputWPack,
      ElementQScale,
      typename std::conditional<kHasQuantOffset, ElementQOffset, std::monostate>::type,
      LayoutInputQScale,
      QuantBlocking,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ShapeMMAThreadBlock,
      ShapeMMAWarp,
      ShapeMMAOp,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages>;

  using Arguments = typename Gemm::Arguments;

  /**
   * Invoke gemm kernel (the version with quantization offset).
   *
   * Returns kErrorNotSupported when this instantiation has no quantization
   * offset (kHasQuantOffset == false), or when the problem M exceeds what the
   * small-M tile configuration supports; otherwise forwards the status of
   * can_implement() / the kernel launch.
   */
  static cutlass::Status run(
      cudaStream_t stream,
      const cutlass::gemm::GemmCoord& problem_size_,
      cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
      cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
      cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
      cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_Qoffset_,
      cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
      cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
      typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
    if constexpr (!kHasQuantOffset) {
      // This overload takes an offset tensor; reject instantiations without offsets.
      return cutlass::Status::kErrorNotSupported;
    } else {
      if constexpr (ShapeMMAThreadBlock::kM == 16) {
        if (problem_size_.m() > 16) {
          // For M > 16, the caller should have picked the
          // kernel with bigger M
          return cutlass::Status::kErrorNotSupported;
        }
      }

      // Construct Gemm arguments
      Arguments args{
          problem_size_,
          ref_A_,
          ref_B_,
          ref_Qscale_,
          ref_Qoffset_,
          ref_C_,
          ref_D_,
          epilogue_};

      Gemm gemm_op;

      // Check if this GEMM can be run or not
      cutlass::Status status = gemm_op.can_implement(args);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }

      // Launch the CUTLASS GEMM kernel.
      return gemm_op(args, nullptr, stream);
    }
  }

  /**
   * Invoke gemm kernel (the version without quantization offset).
   *
   * Returns kErrorNotSupported when this instantiation requires a quantization
   * offset (kHasQuantOffset == true), or when the problem M exceeds what the
   * small-M tile configuration supports; otherwise forwards the status of
   * can_implement() / the kernel launch.
   */
  static cutlass::Status run(
      cudaStream_t stream,
      const cutlass::gemm::GemmCoord& problem_size_,
      cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
      cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
      cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
      cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
      cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
      typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
    if constexpr (kHasQuantOffset) {
      // This overload has no offset tensor; reject instantiations that need offsets.
      return cutlass::Status::kErrorNotSupported;
    } else {
      if constexpr (ShapeMMAThreadBlock::kM == 16) {
        if (problem_size_.m() > 16) {
          // For M > 16, the caller should have picked the
          // kernel with bigger M
          return cutlass::Status::kErrorNotSupported;
        }
      }

      // Construct Gemm arguments
      Arguments args{
          problem_size_,
          ref_A_,
          ref_B_,
          ref_Qscale_,
          ref_C_,
          ref_D_,
          epilogue_};

      Gemm gemm_op;

      // Check if this GEMM can be run or not
      cutlass::Status status = gemm_op.can_implement(args);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }

      // Launch the CUTLASS GEMM kernel.
      return gemm_op(args, nullptr, stream);
    }
  }
};

}  // namespace cuda
}  // namespace onnxruntime

onnxruntime/core/mickey/blk_q4/prepack_sm80.h renamed to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* Licensed under the MIT License.
44
*
55
* Module Name:
6-
* prepack_sm80.h
6+
* blk_q4/f16_prepack_sm80.h
77
*
88
* Abstract:
99
* Prepack weights and quantization parameters (scales and offsets) for

0 commit comments

Comments
 (0)