diff --git a/CMakeLists.txt b/CMakeLists.txt
index edc64f87730..5cc4e577ab4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -245,6 +245,7 @@ set(VLLM_EXT_SRC
     "csrc/quantization/gptq/q_gemm.cu"
     "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
     "csrc/quantization/fp8/common.cu"
+    "csrc/quantization/fp8/per_token_group_quant.cu"
     "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
     "csrc/quantization/gguf/gguf_kernel.cu"
     "csrc/quantization/activation_kernels.cu"
diff --git a/csrc/ops.h b/csrc/ops.h
index 7f3e6b6923a..fdd3071c56e 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -297,6 +297,11 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                                torch::Tensor& scales,
                                std::optional<torch::Tensor> const& azp);
 
+void per_token_group_quant_fp8(const torch::Tensor& input,
+                               torch::Tensor& output_q, torch::Tensor& output_s,
+                               int64_t group_size, double eps, double fp8_min,
+                               double fp8_max, bool scale_ue8m0);
+
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/fp8/per_token_group_quant.cu
new file mode 100644
index 00000000000..afc41faeca9
--- /dev/null
+++ b/csrc/quantization/fp8/per_token_group_quant.cu
@@ -0,0 +1,213 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/util/Float8_e4m3fn.h>
+
+#include <cmath>
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include <torch/all.h>
+
+#include "../vectorization.cuh"
+#include "../vectorization_utils.cuh"
+#include "../../dispatch_utils.h"
+
+__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
+  unsigned mask = 0xffff;
+
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
+  return val;
+}
+
+template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
+          bool SCALE_UE8M0 = false, typename scale_packed_t = float>
+__global__ void per_token_group_quant_8bit_kernel(
+    const T* __restrict__ input, void* __restrict__ output_q,
+    scale_packed_t* __restrict__ output_s, const int group_size,
+    const int num_groups, const int groups_per_block, const float eps,
+    const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
+    const int scale_stride = 0) {
+  const int threads_per_group = 16;
+  const int64_t local_group_id = threadIdx.x / threads_per_group;
+  const int lane_id = threadIdx.x % threads_per_group;
+
+  const int64_t block_group_id = blockIdx.x * groups_per_block;
+  const int64_t global_group_id = block_group_id + local_group_id;
+  const int64_t block_group_offset = global_group_id * group_size;
+
+  float local_absmax = eps;
+
+  using scale_element_t = float;
+  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
+
+  const T* group_input = input + block_group_offset;
+  DST_DTYPE* group_output =
+      static_cast<DST_DTYPE*>(output_q) + block_group_offset;
+  scale_element_t* scale_output;
+
+  if constexpr (IS_COLUMN_MAJOR) {
+    const int num_elems_per_pack =
+        static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
+    const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
+    const int row_idx = global_group_id / scale_num_rows_element;
+    const int col_idx_raw = global_group_id % scale_num_rows_element;
+    const int col_idx = col_idx_raw / num_elems_per_pack;
+    const int pack_idx = col_idx_raw % num_elems_per_pack;
+    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
+                   (col_idx * scale_stride * num_elems_per_pack +
+                    row_idx * num_elems_per_pack + pack_idx);
+  } else {
+    scale_output = output_s + global_group_id;
+  }
+
+  // shared memory to cache each group's data to avoid double DRAM reads.
+  extern __shared__ __align__(16) char smem_raw[];
+  T* smem = reinterpret_cast<T*>(smem_raw);
+  T* smem_group = smem + local_group_id * group_size;
+
+  constexpr int vec_size = 16 / sizeof(T);
+  using vec_t = vllm::vec_n_t<T, vec_size>;
+
+  // copy global -> shared & compute absmax
+  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
+    float abs_v = fabsf(static_cast<float>(src));
+    local_absmax = fmaxf(local_absmax, abs_v);
+    dst = src;
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      group_input,        // in
+      smem_group,         // out (shared)
+      group_size,         // elements per group
+      lane_id,            // thread id
+      threads_per_group,  // stride in group
+      scalar_op_cache);   // scalar handler
+
+  local_absmax = GroupReduceMax(local_absmax, lane_id);
+
+  float y_s = local_absmax / max_8bit;
+  if constexpr (SCALE_UE8M0) {
+    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
+  }
+
+  scale_element_t y_s_quant = y_s;
+
+  if (lane_id == 0) {
+    *scale_output = y_s_quant;
+  }
+
+  __syncthreads();
+
+  // quantize shared -> global 8-bit
+  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
+    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
+    dst = DST_DTYPE(q);
+  };
+
+  vllm::vectorize_with_alignment<vec_size>(
+      smem_group,         // in (shared)
+      group_output,       // out (global quant tensor)
+      group_size,         // elements
+      lane_id,            // tid
+      threads_per_group,  // stride
+      scalar_op_quant);   // scalar handler
+}
+
+void per_token_group_quant_8bit(const torch::Tensor& input,
+                                torch::Tensor& output_q,
+                                torch::Tensor& output_s, int64_t group_size,
+                                double eps, double min_8bit, double max_8bit,
+                                bool scale_ue8m0 = false) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(output_q.is_contiguous());
+
+  const int num_groups = input.numel() / group_size;
+
+  TORCH_CHECK(input.numel() % group_size == 0);
+  TORCH_CHECK(output_s.dim() == 2);
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int THREADS_PER_GROUP = 16;
+
+  int groups_per_block = 1;
+
+  if (num_groups % 16 == 0) {
+    groups_per_block = 16;
+  } else if (num_groups % 8 == 0) {
+    groups_per_block = 8;
+  } else if (num_groups % 4 == 0) {
+    groups_per_block = 4;
+  } else if (num_groups % 2 == 0) {
+    groups_per_block = 2;
+  }
+
+  auto dst_type = output_q.scalar_type();
+  const int num_blocks = num_groups / groups_per_block;
+  const int num_threads = groups_per_block * THREADS_PER_GROUP;
+
+  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
+  const int scale_num_rows = output_s.size(1);
+  const int scale_stride = output_s.stride(1);
+
+#define LAUNCH_KERNEL(T, DST_DTYPE)                                         \
+  do {                                                                      \
+    dim3 grid(num_blocks);                                                  \
+    dim3 block(num_threads);                                                \
+    size_t smem_bytes =                                                     \
+        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);     \
+    if (is_column_major) {                                                  \
+      if (scale_ue8m0) {                                                    \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true>         \
+            <<<grid, block, smem_bytes, stream>>>(                          \
+                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),     \
+                static_cast<float*>(output_s.data_ptr()), group_size,       \
+                num_groups, groups_per_block, (float)eps, (float)min_8bit,  \
+                (float)max_8bit, scale_num_rows, scale_stride);             \
+      } else {                                                              \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false>        \
+            <<<grid, block, smem_bytes, stream>>>(                          \
+                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),     \
+                static_cast<float*>(output_s.data_ptr()), group_size,       \
+                num_groups, groups_per_block, (float)eps, (float)min_8bit,  \
+                (float)max_8bit, scale_num_rows, scale_stride);             \
+      }                                                                     \
+    } else {                                                                \
+      if (scale_ue8m0) {                                                    \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true>        \
+            <<<grid, block, smem_bytes, stream>>>(                          \
+                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),     \
+                static_cast<float*>(output_s.data_ptr()), group_size,       \
+                num_groups, groups_per_block, (float)eps, (float)min_8bit,  \
+                (float)max_8bit);                                           \
+      } else {                                                              \
+        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false>       \
+            <<<grid, block, smem_bytes, stream>>>(                          \
+                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),     \
+                static_cast<float*>(output_s.data_ptr()), group_size,       \
+                num_groups, groups_per_block, (float)eps, (float)min_8bit,  \
+                (float)max_8bit);                                           \
+      }                                                                     \
+    }                                                                       \
+  } while (0)
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "per_token_group_quant_8bit", ([&] {
+        if (dst_type == at::ScalarType::Float8_e4m3fn) {
+          LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
+        }
+      }));
+
+#undef LAUNCH_KERNEL
+}
+
+void per_token_group_quant_fp8(const torch::Tensor& input,
+                               torch::Tensor& output_q, torch::Tensor& output_s,
+                               int64_t group_size, double eps, double fp8_min,
+                               double fp8_max, bool scale_ue8m0) {
+  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
+                             fp8_min, fp8_max, scale_ue8m0);
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 79e2575974b..d310211afe4 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -601,6 +601,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+  // Compute per-token-group FP8 quantized tensor and scaling factor.
+  ops.def(
+      "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
+      "output_s, "
+      "int group_size, float eps, float fp8_min, float fp8_max, bool "
+      "scale_ue8m0) -> ()");
+  ops.impl("per_token_group_fp8_quant", torch::kCUDA,
+           &per_token_group_quant_fp8);
+
   // Mamba selective scan kernel
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
diff --git a/tests/kernels/quantization/test_per_token_group_quant.py b/tests/kernels/quantization/test_per_token_group_quant.py
new file mode 100644
index 00000000000..f826983fe94
--- /dev/null
+++ b/tests/kernels/quantization/test_per_token_group_quant.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.quantization.utils import fp8_utils
+
+
+@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
+@pytest.mark.parametrize("column_major", [False, True])
+@pytest.mark.parametrize("scale_ue8m0", [False, True])
+@pytest.mark.parametrize("group_size", [64, 128])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_per_token_group_quant_fp8(shape, column_major: bool,
+                                   scale_ue8m0: bool, group_size: int):
+    device = "cuda"
+
+    torch.manual_seed(42)
+    num_tokens, hidden_dim = shape
+
+    x = (torch.randn(
+        (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
+
+    # cuda path
+    out_q, scale = fp8_utils.per_token_group_quant_fp8(
+        x,
+        group_size,
+        column_major_scales=column_major,
+        use_ue8m0=scale_ue8m0,
+    )
+
+    # triton ref
+    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
+            x,
+            group_size,
+            column_major_scales=column_major,
+            use_ue8m0=scale_ue8m0,
+        )
+
+    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
+    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 20e7b444856..ee5f2b51564 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
     out_q: Optional[torch.Tensor] = None,
+    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -397,8 +398,7 @@ def per_token_group_quant_fp8(
     if x_q is None:
         x_q = torch.empty_like(x, device=x.device, dtype=dtype)
 
-    M = x.numel() // group_size
-    N = group_size
+    # Allocate the scale tensor in either row- or column-major format.
     if column_major_scales:
         shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
         x_s = torch.empty(shape, device=x.device,
@@ -407,6 +407,15 @@ def per_token_group_quant_fp8(
         shape = x.shape[:-1] + (x.shape[-1] // group_size, )
         x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
 
+    # prefer CUDA kernel if available
+    if current_platform.is_cuda() and x.is_contiguous():
+        torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
+                                               fp8_min, fp8_max, use_ue8m0)
+        return x_q, x_s
+
+    # TRITON FALLBACK
+    M = x.numel() // group_size
+    N = group_size
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
@@ -423,7 +432,7 @@ def per_token_group_quant_fp8(
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
-        use_ue8m0=is_blackwell_deep_gemm_used(),
+        use_ue8m0=use_ue8m0,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -439,7 +448,7 @@ def per_token_group_quant_fp8(
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
-        use_ue8m0=is_blackwell_deep_gemm_used(),
+        use_ue8m0=use_ue8m0,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
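
For reference, a minimal usage sketch of the Python entry point touched by this patch (not part of the diff). It mirrors tests/kernels/quantization/test_per_token_group_quant.py above and assumes a CUDA build of vLLM with the `_C` extension compiled, so that the new `torch.ops._C.per_token_group_fp8_quant` op is available; the dtype comment assumes the platform FP8 dtype is e4m3fn, which is the case on CUDA.

```python
# Usage sketch only -- assumes a CUDA build of vLLM with the new
# per_token_group_fp8_quant custom op registered in torch.ops._C.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)

# The helper allocates x_q/x_s and dispatches to the CUDA kernel when the
# input is contiguous; otherwise it falls back to the Triton path.
x_q, x_s = per_token_group_quant_fp8(x, group_size=128,
                                     column_major_scales=False)

print(x_q.shape, x_q.dtype)  # torch.Size([32, 256]), FP8 (e4m3fn on CUDA)
print(x_s.shape, x_s.dtype)  # torch.Size([32, 2]), torch.float32 (one scale per group)
```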