Merged
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -245,6 +245,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/per_token_group_quant.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -297,6 +297,11 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales,
std::optional<torch::Tensor> const& azp);

void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0);

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
213 changes: 213 additions & 0 deletions csrc/quantization/fp8/per_token_group_quant.cu
@@ -0,0 +1,213 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <torch/all.h>

#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"

__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
unsigned mask = 0xffff;

val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
return val;
}

template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
const T* __restrict__ input, void* __restrict__ output_q,
scale_packed_t* __restrict__ output_s, const int group_size,
const int num_groups, const int groups_per_block, const float eps,
const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
const int scale_stride = 0) {
const int threads_per_group = 16;
const int64_t local_group_id = threadIdx.x / threads_per_group;
const int lane_id = threadIdx.x % threads_per_group;

const int64_t block_group_id = blockIdx.x * groups_per_block;
const int64_t global_group_id = block_group_id + local_group_id;
const int64_t block_group_offset = global_group_id * group_size;

float local_absmax = eps;

using scale_element_t = float;
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);

const T* group_input = input + block_group_offset;
DST_DTYPE* group_output =
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
scale_element_t* scale_output;

if constexpr (IS_COLUMN_MAJOR) {
const int num_elems_per_pack =
static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
const int row_idx = global_group_id / scale_num_rows_element;
const int col_idx_raw = global_group_id % scale_num_rows_element;
const int col_idx = col_idx_raw / num_elems_per_pack;
const int pack_idx = col_idx_raw % num_elems_per_pack;
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
(col_idx * scale_stride * num_elems_per_pack +
row_idx * num_elems_per_pack + pack_idx);
} else {
scale_output = output_s + global_group_id;
}

// shared memory to cache each group's data to avoid double DRAM reads.
extern __shared__ __align__(16) char smem_raw[];
T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size;

constexpr int vec_size = 16 / sizeof(T);
using vec_t = vllm::vec_n_t<T, vec_size>;

// copy global -> shared & compute absmax
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
float abs_v = fabsf(static_cast<float>(src));
local_absmax = fmaxf(local_absmax, abs_v);
dst = src;
};

vllm::vectorize_with_alignment<vec_size>(
group_input, // in
smem_group, // out (shared)
group_size, // elements per group
lane_id, // thread id
threads_per_group, // stride in group
scalar_op_cache); // scalar handler

local_absmax = GroupReduceMax(local_absmax, lane_id);

float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
}

scale_element_t y_s_quant = y_s;

if (lane_id == 0) {
*scale_output = y_s_quant;
}

__syncthreads();

// quantize shared -> global 8-bit
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
dst = DST_DTYPE(q);
};

vllm::vectorize_with_alignment<vec_size>(
smem_group, // in (shared)
group_output, // out (global quant tensor)
group_size, // elements
lane_id, // tid
threads_per_group, // stride
scalar_op_quant); // scalar handler
}

void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double min_8bit, double max_8bit,
bool scale_ue8m0 = false) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(output_q.is_contiguous());
Comment on lines +124 to +125
Member
Contiguous might be a problem for MLA, so please test a couple of DeepSeek evals/benchmarks.

Member Author
You are right, so I chose to fall back to Triton when the input is not contiguous.
Now it works:

VLLM_USE_DEEP_GEMM=1 lm_eval --model vllm --model_args "pretrained=deepseek-ai/DeepSeek-R1,data_parallel_size=8,gpu_memory_utilization=0.95,max_model_len=16384,enable_expert_parallel=True" --tasks gsm8k --batch_size auto --num_fewshot 5
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9560|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9545|±  |0.0057|


const int num_groups = input.numel() / group_size;

TORCH_CHECK(input.numel() % group_size == 0);
TORCH_CHECK(output_s.dim() == 2);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

constexpr int THREADS_PER_GROUP = 16;

int groups_per_block = 1;

if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}

auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block;
const int num_threads = groups_per_block * THREADS_PER_GROUP;

const bool is_column_major = output_s.stride(0) < output_s.stride(1);
const int scale_num_rows = output_s.size(1);
const int scale_stride = output_s.stride(1);

#define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
size_t smem_bytes = \
static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
if (is_column_major) { \
if (scale_ue8m0) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit, scale_num_rows, scale_stride); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit, scale_num_rows, scale_stride); \
} \
} else { \
if (scale_ue8m0) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false> \
<<<grid, block, smem_bytes, stream>>>( \
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), group_size, \
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
(float)max_8bit); \
} \
} \
} while (0)

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
}
}));

#undef LAUNCH_KERNEL
}

void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
fp8_min, fp8_max, scale_ue8m0);
}
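
For orientation, the per-group math the kernel implements is small enough to state as a PyTorch reference: take the group absmax (floored at `eps`), set the scale to `absmax / fp8_max`, optionally round the scale up to a power of two (the UE8M0 path), then divide, clamp, and cast. A minimal sketch, assuming e4m3 output, row-major scales, and a last dimension divisible by `group_size`; the function name is illustrative and not part of the PR:

```python
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int,
                                  eps: float = 1e-10,
                                  scale_ue8m0: bool = False):
    """Unvectorized reference for the math in per_token_group_quant_8bit_kernel."""
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_min, fp8_max = fp8_info.min, fp8_info.max

    # One scale per contiguous group of `group_size` elements in the last dim
    # (assumes x.shape[-1] % group_size == 0).
    groups = x.float().reshape(-1, group_size)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / fp8_max
    if scale_ue8m0:
        # Round the scale up to the nearest power of two (UE8M0-style scale).
        scale = torch.exp2(torch.ceil(torch.log2(scale.clamp_min(1e-10))))

    q = (groups / scale).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn)
    q = q.reshape(x.shape)
    scale = scale.reshape(*x.shape[:-1], x.shape[-1] // group_size)
    return q, scale
```

The Triton path computes the same quantity, which is why the test added below can compare the two implementations with loose tolerances.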
9 changes: 9 additions & 0 deletions csrc/torch_bindings.cpp
@@ -601,6 +601,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);

// Compute per-token-group FP8 quantized tensor and scaling factor.
ops.def(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0) -> ()");
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
&per_token_group_quant_fp8);

// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
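
With the binding above, the op is reachable from Python through the `torch.ops` registry once the extension is loaded. A rough sketch of a direct call, mirroring the checks in `per_token_group_quant_8bit` (contiguous input, 2-D scale tensor); the shapes below are arbitrary:

```python
import torch

x = torch.randn(32, 512, device="cuda", dtype=torch.bfloat16).contiguous()
group_size = 128
fp8_info = torch.finfo(torch.float8_e4m3fn)

x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
# Row-major scales: one scale per group along the last dimension.
x_s = torch.empty(x.shape[0], x.shape[1] // group_size,
                  device=x.device, dtype=torch.float32)

torch.ops._C.per_token_group_fp8_quant(
    x, x_q, x_s, group_size,
    1e-10,         # eps
    fp8_info.min,  # fp8_min
    fp8_info.max,  # fp8_max
    False)         # scale_ue8m0
```

In practice callers go through `per_token_group_quant_fp8` in `fp8_utils.py` (changed later in this PR), which allocates the output tensors and picks between this op and the Triton kernel.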
44 changes: 44 additions & 0 deletions tests/kernels/quantization/test_per_token_group_quant.py
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils


@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
scale_ue8m0: bool, group_size: int):
device = "cuda"

torch.manual_seed(42)
num_tokens, hidden_dim = shape

x = (torch.randn(
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)

# cuda path
out_q, scale = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
column_major_scales=column_major,
use_ue8m0=scale_ue8m0,
)

# triton ref
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
x,
group_size,
column_major_scales=column_major,
use_ue8m0=scale_ue8m0,
)

assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
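
The same `is_cuda` patch the test uses to force the Triton reference also gives a quick way to time the two paths against each other; a rough benchmarking sketch (not part of the PR, and timings will vary by GPU and shape):

```python
from unittest.mock import patch

import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils


def time_quant(fn, iters: int = 100) -> float:
    # Returns average milliseconds per call, measured with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warmup
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


x = torch.randn(4096, 7168, device="cuda", dtype=torch.bfloat16)

cuda_ms = time_quant(lambda: fp8_utils.per_token_group_quant_fp8(x, 128))
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
    triton_ms = time_quant(lambda: fp8_utils.per_token_group_quant_fp8(x, 128))

print(f"CUDA kernel: {cuda_ms:.3f} ms, Triton: {triton_ms:.3f} ms")
```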
17 changes: 13 additions & 4 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
use_ue8m0: bool = is_blackwell_deep_gemm_used(),
Member
I worry about setting this as the default, since this function could be used on Blackwell but with the CUTLASS or FlashInfer FP8 block kernels that are now on SM100.

Member Author
is_blackwell_deep_gemm_used checks the env var VLLM_USE_DEEP_GEMM as well, so it won't cause trouble now.
That said, this default isn't great either: DeepGEMM now supports float32 scales on B200, so I will open another PR to let the user decide whether to use e8m0.

) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
@@ -397,8 +398,7 @@ def per_token_group_quant_fp8(
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)

M = x.numel() // group_size
N = group_size
# Allocate the scale tensor in either row- or column-major format.
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
@@ -407,6 +407,15 @@
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

# prefer CUDA kernel if available
if current_platform.is_cuda() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
fp8_min, fp8_max, use_ue8m0)
return x_q, x_s

# TRITON FALLBACK
M = x.numel() // group_size
N = group_size
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
@@ -423,7 +432,7 @@
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=is_blackwell_deep_gemm_used(),
use_ue8m0=use_ue8m0,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
@@ -439,7 +448,7 @@
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
use_ue8m0=is_blackwell_deep_gemm_used(),
use_ue8m0=use_ue8m0,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
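
Taken together with the Triton fallback above, the dispatch rule is: contiguous CUDA inputs go to the new torch.ops._C.per_token_group_fp8_quant kernel, while anything else (for example the non-contiguous MLA case raised in review) keeps using the Triton kernel. A small sketch under those assumptions:

```python
import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils

group_size = 128

# Contiguous input: handled by the new CUDA op.
x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
q, s = fp8_utils.per_token_group_quant_fp8(x, group_size)

# Non-contiguous input (e.g. a transposed view): x_t.is_contiguous() is False,
# so per_token_group_quant_fp8 skips the CUDA op and runs the Triton kernel.
x_t = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16).t()
assert not x_t.is_contiguous()
q_t, s_t = fp8_utils.per_token_group_quant_fp8(x_t, group_size)
```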