[Perf] Cuda Kernel for Per Token Group Quant #21083
@@ -0,0 +1,213 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <torch/all.h>

#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"

__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  unsigned mask = 0xffff;

  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}

template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
          bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    scale_packed_t* __restrict__ output_s, const int group_size,
    const int num_groups, const int groups_per_block, const float eps,
    const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
    const int scale_stride = 0) {
  const int threads_per_group = 16;
  const int64_t local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int64_t block_group_id = blockIdx.x * groups_per_block;
  const int64_t global_group_id = block_group_id + local_group_id;
  const int64_t block_group_offset = global_group_id * group_size;

  float local_absmax = eps;

  using scale_element_t = float;
  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output =
      static_cast<DST_DTYPE*>(output_q) + block_group_offset;
  scale_element_t* scale_output;

  if constexpr (IS_COLUMN_MAJOR) {
    const int num_elems_per_pack =
        static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
    const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
    const int row_idx = global_group_id / scale_num_rows_element;
    const int col_idx_raw = global_group_id % scale_num_rows_element;
    const int col_idx = col_idx_raw / num_elems_per_pack;
    const int pack_idx = col_idx_raw % num_elems_per_pack;
    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
                   (col_idx * scale_stride * num_elems_per_pack +
                    row_idx * num_elems_per_pack + pack_idx);
  } else {
    scale_output = output_s + global_group_id;
  }

  // shared memory to cache each group's data to avoid double DRAM reads.
  extern __shared__ __align__(16) char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);
  T* smem_group = smem + local_group_id * group_size;

  constexpr int vec_size = 16 / sizeof(T);
  using vec_t = vllm::vec_n_t<T, vec_size>;

  // copy global -> shared & compute absmax
  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
    float abs_v = fabsf(static_cast<float>(src));
    local_absmax = fmaxf(local_absmax, abs_v);
    dst = src;
  };

  vllm::vectorize_with_alignment<vec_size>(
      group_input,        // in
      smem_group,         // out (shared)
      group_size,         // elements per group
      lane_id,            // thread id
      threads_per_group,  // stride in group
      scalar_op_cache);   // scalar handler

  local_absmax = GroupReduceMax(local_absmax, lane_id);

  float y_s = local_absmax / max_8bit;
  if constexpr (SCALE_UE8M0) {
    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
  }

  scale_element_t y_s_quant = y_s;

  if (lane_id == 0) {
    *scale_output = y_s_quant;
  }

  __syncthreads();

  // quantize shared -> global 8-bit
  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
    dst = DST_DTYPE(q);
  };

  vllm::vectorize_with_alignment<vec_size>(
      smem_group,         // in (shared)
      group_output,       // out (global quant tensor)
      group_size,         // elements
      lane_id,            // tid
      threads_per_group,  // stride
      scalar_op_quant);   // scalar handler
}

void per_token_group_quant_8bit(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double min_8bit, double max_8bit,
                                bool scale_ue8m0 = false) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(output_q.is_contiguous());

  const int num_groups = input.numel() / group_size;

  TORCH_CHECK(input.numel() % group_size == 0);
  TORCH_CHECK(output_s.dim() == 2);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;

  int groups_per_block = 1;

  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
  const int scale_num_rows = output_s.size(1);
  const int scale_stride = output_s.stride(1);

#define LAUNCH_KERNEL(T, DST_DTYPE) \
  do { \
    dim3 grid(num_blocks); \
    dim3 block(num_threads); \
    size_t smem_bytes = \
        static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
    if (is_column_major) { \
      if (scale_ue8m0) { \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true> \
            <<<grid, block, smem_bytes, stream>>>( \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
                static_cast<float*>(output_s.data_ptr()), group_size, \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit, scale_num_rows, scale_stride); \
      } else { \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false> \
            <<<grid, block, smem_bytes, stream>>>( \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
                static_cast<float*>(output_s.data_ptr()), group_size, \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit, scale_num_rows, scale_stride); \
      } \
    } else { \
      if (scale_ue8m0) { \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true> \
            <<<grid, block, smem_bytes, stream>>>( \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
                static_cast<float*>(output_s.data_ptr()), group_size, \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit); \
      } else { \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false> \
            <<<grid, block, smem_bytes, stream>>>( \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
                static_cast<float*>(output_s.data_ptr()), group_size, \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit); \
      } \
    } \
  } while (0)

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "per_token_group_quant_8bit", ([&] {
        if (dst_type == at::ScalarType::Float8_e4m3fn) {
          LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
        }
      }));

#undef LAUNCH_KERNEL
}

void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,
                               double fp8_max, bool scale_ue8m0) {
  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                             fp8_min, fp8_max, scale_ue8m0);
}
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils


@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
                                   scale_ue8m0: bool, group_size: int):
    device = "cuda"

    torch.manual_seed(42)
    num_tokens, hidden_dim = shape

    x = (torch.randn(
        (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)

    # cuda path
    out_q, scale = fp8_utils.per_token_group_quant_fp8(
        x,
        group_size,
        column_major_scales=column_major,
        use_ue8m0=scale_ue8m0,
    )

    # triton ref
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x,
            group_size,
            column_major_scales=column_major,
            use_ue8m0=scale_ue8m0,
        )

    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
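For orientation, the shapes the test exercises follow from the row-major scale allocation (`x.shape[:-1] + (x.shape[-1] // group_size,)`) in the wrapper diff further below. A rough usage sketch, assuming a CUDA device and the default FP8 output dtype:

import torch
from vllm.model_executor.layers.quantization.utils import fp8_utils

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)

# One FP8 value per input element, one float32 scale per 128-wide group.
x_q, x_s = fp8_utils.per_token_group_quant_fp8(x, group_size=128)
print(x_q.shape)  # torch.Size([32, 256])
print(x_s.shape)  # torch.Size([32, 2])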
@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
     out_q: Optional[torch.Tensor] = None,
+    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
Review comment: I worry about setting this as a default value, since this function could be used on Blackwell but with the CUTLASS or FlashInfer FP8 block kernels that are now on SM100.
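Not part of this diff, but to illustrate the concern: a hypothetical SM100 call site that targets the CUTLASS or FlashInfer FP8 block kernels could pass the flag explicitly rather than inheriting the DeepGEMM-oriented default, e.g.:

# Hypothetical call site (illustrative only): a Blackwell path feeding
# CUTLASS/FlashInfer FP8 block kernels that does not want UE8M0
# (power-of-two) scales opts out explicitly instead of relying on the
# is_blackwell_deep_gemm_used() default.
x_q, x_s = per_token_group_quant_fp8(
    x,
    group_size=128,
    column_major_scales=True,
    use_ue8m0=False,
)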
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the

@@ -397,8 +398,7 @@ def per_token_group_quant_fp8(
     if x_q is None:
         x_q = torch.empty_like(x, device=x.device, dtype=dtype)

-    M = x.numel() // group_size
-    N = group_size
+    # Allocate the scale tensor in either row- or column-major format.
     if column_major_scales:
         shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
         x_s = torch.empty(shape, device=x.device,

@@ -407,6 +407,15 @@
         shape = x.shape[:-1] + (x.shape[-1] // group_size, )
         x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

+    # prefer CUDA kernel if available
+    if current_platform.is_cuda() and x.is_contiguous():
+        torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
+                                               fp8_min, fp8_max, use_ue8m0)
+        return x_q, x_s
+
+    # TRITON FALLBACK
+    M = x.numel() // group_size
+    N = group_size
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)

@@ -423,7 +432,7 @@
             eps,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
-            use_ue8m0=is_blackwell_deep_gemm_used(),
+            use_ue8m0=use_ue8m0,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,

@@ -439,7 +448,7 @@
             eps,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
-            use_ue8m0=is_blackwell_deep_gemm_used(),
+            use_ue8m0=use_ue8m0,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
Review comment: The contiguity requirement might be a problem for MLA, so please test a couple of DeepSeek evals/benchmarks.

Reply: You are right, so I chose to fall back to Triton when the input is not contiguous. Now it works.
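A minimal illustration of that dispatch condition (a sketch, assuming a CUDA device; the tensor here is made up, but a transposed view is the kind of strided, non-contiguous input that would take the Triton path):

import torch

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
print(x.is_contiguous())    # True  -> torch.ops._C.per_token_group_fp8_quant

x_t = x.t()                 # strided view, as a transpose would produce
print(x_t.is_contiguous())  # False -> Triton fallback path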