/**
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License.
 *
 * Module Name:
 *    blk_q4/f16_gemm_sm80.h
 *
 * Abstract:
 *   Entry point for Q4F16 GEMM kernel for SM80 devices.
 */

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass_ext/q4gemm/device/quantb_gemm.h"

namespace onnxruntime {
namespace cuda {

//
// Implementation of the quantized GEMM kernel: 16b float activations x block-wise quantized 4b weights
//
template <
    typename ElementDequant_,  // <- data type of dequantized elements for gemm, fp16 or bf16
    typename QuantBlocking_,   // <- weights block per scale, cutlass::MatrixShape<x,y>
    bool SmallM,               // <- true if M <= 16
    bool kHasQuantOffset>
struct BlkQ4F16GemmImpl {
  //
  // Type definitions
  //

  using ElementDequant = ElementDequant_;
  using QuantBlocking = QuantBlocking_;

  static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kernel only supports 16b operands!");

  // Data types that are fixed for this kernel
  using ElementAccumulator = float;
  using ElementComputeEpilogue = ElementAccumulator;
  using ElementInputA = ElementDequant;
  using ElementOutput = ElementDequant;

  using ElementW = uint8_t;  // <- weights are int4, one uint8 holds two of them

  // We pack 4 weights into one 16b element so we can leverage the cutlass tile iterators
  // for async shared memory loading and minimize bank conflicts
  using ElementWPack = ElementDequant;
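  // (four 4b weights = 16 bits, the same width as ElementDequant per the assert above)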

  using ElementQScale = ElementDequant;  // <- data type of quantization scale
  using ElementQOffset = uint8_t;

  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputWPack = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Layout of quantization scale and offset, oriented so they can be loaded
  // with fewer instructions in a warp tile
  using LayoutInputQScale =
      typename std::conditional<QuantBlocking::kRow == 1,
                                cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;  // <- layout of quantization scale

  using ShapeMMAThreadBlock =
      typename std::conditional<SmallM,
                                cutlass::gemm::GemmShape<16, 64, 64>,
                                cutlass::gemm::GemmShape<128, 256, 64>>::type;

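  // Clamp the warp tile's N extent from below so that, in the SmallM case, a warp
  // tile spans at least one full quantization block (and never less than 32) along N,
  // presumably so a warp never straddles a partial scale block.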
  static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32;
  using ShapeMMAWarp =
      typename std::conditional<SmallM,
                                cutlass::gemm::GemmShape<16, MinN, 64>,
                                cutlass::gemm::GemmShape<64, 64, 64>>::type;

  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

  // This code section describes how threadblocks are scheduled on GPU
  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- default identity swizzle

  // This code section describes the epilogue part of the kernel
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,                                     // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
                                                         // memory access. For 16b output, that's 8
                                                         // elements. This becomes the vector width of
                                                         // math instructions in the epilogue too
      ElementAccumulator,                                // <- data type of accumulator
      ElementComputeEpilogue>;                           // <- data type for alpha/beta in linear combination function
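  // i.e. the epilogue computes D = alpha * accumulator + beta * C, with
  // alpha/beta carried as ElementComputeEpilogue (fp32)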

  // Number of pipeline stages for the mainloop
  static constexpr int NumStages = 3;

  using Gemm = cutlass::gemm::device::QuantBGemm<
      ElementInputA,
      LayoutInputA,
      ElementWPack,
      LayoutInputWPack,
      ElementQScale,
      typename std::conditional<kHasQuantOffset, ElementQOffset, std::monostate>::type,  // <- monostate when offsets are unused
      LayoutInputQScale,
      QuantBlocking,
      ElementOutput,
      LayoutOutput,
      ElementAccumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ShapeMMAThreadBlock,
      ShapeMMAWarp,
      ShapeMMAOp,
      EpilogueOp,
      SwizzleThreadBlock,
      NumStages>;

  using Arguments = typename Gemm::Arguments;

  // Invoke gemm kernel (the version with quantization offset)
  static cutlass::Status run(
      cudaStream_t stream,
      const cutlass::gemm::GemmCoord& problem_size_,
      cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
      cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
      cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
      cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_Qoffset_,
      cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
      cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
      typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
    if constexpr (!kHasQuantOffset) {
      return cutlass::Status::kErrorNotSupported;
    } else {
      if constexpr (ShapeMMAThreadBlock::kM == 16) {
        if (problem_size_.m() > 16) {
          // For M > 16, the caller should have picked the
          // kernel with the bigger M tile
          return cutlass::Status::kErrorNotSupported;
        }
      }

      // Construct Gemm arguments
      Arguments args{
          problem_size_,
          ref_A_,
          ref_B_,
          ref_Qscale_,
          ref_Qoffset_,
          ref_C_,
          ref_D_,
          epilogue_};

      Gemm gemm_op;

      // Check whether this GEMM can be run
      cutlass::Status status = gemm_op.can_implement(args);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }

      // Launch the CUTLASS GEMM kernel.
      return gemm_op(args, nullptr, stream);
    }
  }

  // Invoke gemm kernel (the version without quantization offset)
  static cutlass::Status run(
      cudaStream_t stream,
      const cutlass::gemm::GemmCoord& problem_size_,
      cutlass::TensorRef<ElementInputA const, LayoutInputA> ref_A_,
      cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_B_,
      cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_Qscale_,
      cutlass::TensorRef<ElementOutput const, LayoutOutput> ref_C_,
      cutlass::TensorRef<ElementOutput, LayoutOutput> ref_D_,
      typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) {
    if constexpr (kHasQuantOffset) {
      return cutlass::Status::kErrorNotSupported;
    } else {
      if constexpr (ShapeMMAThreadBlock::kM == 16) {
        if (problem_size_.m() > 16) {
          // For M > 16, the caller should have picked the
          // kernel with the bigger M tile
          return cutlass::Status::kErrorNotSupported;
        }
      }

      // Construct Gemm arguments
      Arguments args{
          problem_size_,
          ref_A_,
          ref_B_,
          ref_Qscale_,
          ref_C_,
          ref_D_,
          epilogue_};

      Gemm gemm_op;

      // Check whether this GEMM can be run
      cutlass::Status status = gemm_op.can_implement(args);
      if (status != cutlass::Status::kSuccess) {
        return status;
      }

      // Launch the CUTLASS GEMM kernel.
      return gemm_op(args, nullptr, stream);
    }
  }
};
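
// A minimal usage sketch, for illustration only: the host-side names below
// (stream, m/n/k, and the ref_* TensorRefs) are hypothetical placeholders,
// and weight/scale packing is assumed to have been done elsewhere.
//
//   using GemmRunner = onnxruntime::cuda::BlkQ4F16GemmImpl<
//       cutlass::half_t,              // dequantize weights to fp16
//       cutlass::MatrixShape<1, 32>,  // e.g. one scale per 1x32 weight block
//       false,                        // SmallM = false: M may exceed 16
//       false>;                       // kHasQuantOffset = false
//
//   cutlass::gemm::GemmCoord problem_size(m, n, k);
//   cutlass::Status status = GemmRunner::run(
//       stream, problem_size, ref_A, ref_B, ref_Qscale, ref_C, ref_D);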

}  // namespace cuda
}  // namespace onnxruntime