Merged
98 changes: 98 additions & 0 deletions benchmarks/kernels/bench_per_token_quant_fp8.py
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable

import torch

from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton


# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
    def inner(*args):
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args)

    return inner


torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    torch_per_token_quant_fp8 = torch.compile(
        QuantFP8(False, GroupShape.PER_TOKEN),
        fullgraph=True,
        dynamic=False,  # recompile for different shapes
    )

    # First dim is explicitly dynamic to simulate vLLM usage
    torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)


def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input)


def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between Triton and CUDA implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    if torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
        print("βœ… All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "cuda":
        fn = lambda: cuda_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark_quantization.run(print_data=True)
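
A quick way to exercise the new benchmark end to end (not part of the diff; assumes a CUDA device and that it is run from the vLLM repository root):

# Equivalent to `python benchmarks/kernels/bench_per_token_quant_fp8.py`:
# runs calculate_diff() once, then prints the Triton perf report table.
import runpy

runpy.run_path("benchmarks/kernels/bench_per_token_quant_fp8.py",
               run_name="__main__")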
11 changes: 7 additions & 4 deletions tests/compile/test_fusion.py
@@ -44,7 +44,9 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
        ]
        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            use_per_token_if_dynamic=True)
            act_quant_static=static,
            act_quant_group_shape=group_shape,
        )

    def forward(self, x):
        resid = torch.sqrt(x)
@@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
    vllm_config.compilation_config.pass_config = \
        PassConfig(enable_fusion=True, enable_noop=True)
        level=CompilationLevel.PIECEWISE,
        custom_ops=["+rms_norm", "+quant_fp8"],
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    ))
    with vllm.config.set_current_vllm_config(vllm_config):
        # Reshape pass is needed for the fusion pass to work
        noop_pass = NoOpEliminationPass(vllm_config)
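
Note on the API change exercised above: in these tests, Fp8LinearOp is now constructed with explicit activation-quantization arguments instead of use_per_token_if_dynamic. A condensed before/after sketch (argument names are taken from this diff; the concrete values are illustrative only):

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

# before: Fp8LinearOp(cutlass_fp8_supported=..., use_per_token_if_dynamic=True)
fp8_linear = Fp8LinearOp(
    cutlass_fp8_supported=False,                  # illustrative value
    act_quant_static=False,                       # dynamic activation quantization
    act_quant_group_shape=GroupShape.PER_TOKEN,   # one scale per token
)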
2 changes: 2 additions & 0 deletions tests/compile/test_fusion_attn.py
@@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
        custom_ops=["+quant_fp8"],
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
@@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
        custom_ops=["+quant_fp8"],
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

37 changes: 30 additions & 7 deletions tests/compile/test_silu_mul_quant_fusion.py
@@ -4,33 +4,56 @@
import torch

import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
from vllm.platforms import current_platform

from .backend import TestBackend


class TestModel(torch.nn.Module):

    def __init__(self, *args, **kwargs):
    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.silu_and_mul = SiluAndMul()
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        self.w = (torch.rand(
            hidden_size,
            hidden_size).to(dtype=current_platform.fp8_dtype()).t())

        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def forward(self, x):
        y = self.silu_and_mul(x)
        x2 = scaled_fp8_quant(y, self.scale)
        x2 = self.fp8_linear.apply(y,
                                   self.w,
                                   self.wscale,
                                   input_scale=self.wscale)
        return x2


@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("cutlass_fp8_enabled",
                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
                                   cutlass_fp8_enabled):
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

@@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
    fusion_pass = ActivationQuantFusionPass(config)

    backend = TestBackend(fusion_pass)
    model = TestModel()
    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
    model = TestModel(hidden_size, cutlass_fp8_enabled)

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size)
    x = torch.rand(num_tokens, hidden_size * 2)
    torch._dynamo.mark_dynamic(x, 0)

    result = model(x)
6 changes: 4 additions & 2 deletions vllm/attention/backends/abstract.py
@@ -9,6 +9,8 @@

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
@@ -289,7 +291,7 @@ def forward(
        raise NotImplementedError

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
                                     group_shape: GroupShape):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
@@ -298,7 +300,7 @@ def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
        TODO(luka) merge parameters into QuantDescriptor
        :param dtype: quantized dtype
        :param static: static or dynamic quantization
        :param group_shape: quant group shape. (-1, -1) for per-tensor.
        :param group_shape: quant group shape.
        :return: is fusion supported for this type of quantization
        """
        return False
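
For backend authors, a minimal sketch of an override with the updated signature (MyBackendImpl is a hypothetical class for illustration; the policy mirrors the ROCm Triton backend change in the next file):

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform


class MyBackendImpl:  # hypothetical backend, illustration only
    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: GroupShape):
        # Example policy: only static, per-tensor fp8 output quant is fusable.
        return (dtype == current_platform.fp8_dtype() and static
                and group_shape == GroupShape.PER_TENSOR)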
6 changes: 4 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -19,6 +19,8 @@
                                            PagedAttentionMetadata)
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention

@@ -598,10 +600,10 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
                                        head_dim))

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
                                     group_shape: GroupShape):
        if self.use_triton_flash_attn:
            return dtype == current_platform.fp8_dtype(
            ) and static and group_shape == (-1, -1)  # per-tensor
            ) and static and group_shape == GroupShape.PER_TENSOR

        # Only supported in the Triton backend
        return False
25 changes: 3 additions & 22 deletions vllm/compilation/fusion.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, ClassVar, NamedTuple, Optional
from typing import Callable, NamedTuple, Optional

import torch
import torch._inductor.pattern_matcher as pm
@@ -11,6 +11,8 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform

from .fx_utils import find_getitem_maybe
@@ -33,27 +35,6 @@ def empty_fp32(*args, **kwargs):
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default


# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
    row: int
    col: int


class GroupShape(_GroupShape):
    """
    This class describes the quantization group shape.
    It includes static members for common shapes (per-tensor, per-token).
    """

    # Aliases for common quantization group shapes
    PER_TENSOR: ClassVar['GroupShape']
    PER_TOKEN: ClassVar['GroupShape']


GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)


class QuantKey(NamedTuple):
    """
    Named tuple for identifying the type of quantization.
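
The GroupShape definition removed here is not deleted outright: the rest of this PR imports it from vllm.model_executor.layers.quantization.utils.quant_utils. A minimal sketch of the aliases at their new home (values match the removed lines above):

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# GroupShape is a NamedTuple of (row, col) group sizes; -1 spans the whole axis.
assert GroupShape.PER_TENSOR == GroupShape(-1, -1)  # one scale for the whole tensor
assert GroupShape.PER_TOKEN == GroupShape(1, -1)    # one scale per row (token)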
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/utils.py
@@ -109,6 +109,8 @@ def _fp8_quantize(
    is provided, the output will be blocked.
    """
    if block_shape is None:
        # TODO(luka): use QuantFP8 custom op
        # https://github.com/vllm-project/vllm/issues/20711
        A, A_scale = ops.scaled_fp8_quant(
            A, A_scale, use_per_token_if_dynamic=per_act_token)
    else:
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -15,6 +15,9 @@
                                               QKVParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise, sparse_cutlass_supported)
from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -24,6 +27,8 @@

__all__ = ["CompressedTensors24"]

from vllm.platforms import current_platform


class CompressedTensors24(CompressedTensorsScheme):

@@ -45,6 +50,12 @@ def __init__(
            and self.model_compressor.sparsity_config.format
            == CompressionFormat.sparse_24_bitmask.value)

        if quantized and input_quant is not None and \
                self._get_quant_dtype() == current_platform.fp8_dtype():
            static = not input_quant.dynamic
            g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
            self.quant_fp8 = QuantFP8(static, g_shape)

    @classmethod
    def get_min_capability(cls) -> int:
        # Only cutlass 3.x kernels are implemented so far
@@ -232,21 +243,15 @@ def apply_weights(
        :return: The output tensor of the layer
        """
        if self.quantized:
            scale = None
            if hasattr(layer, "input_scale"):
                scale = layer.input_scale
            scale = getattr(layer, 'input_scale', None)

            if self.weights_dtype == torch.int8:
                ops_output = ops.scaled_int8_quant(x, scale=scale)
                q_input = ops_output[0]
                input_scale = ops_output[1]
            else:
                assert self.weights_dtype == torch.float8_e4m3fn
                if scale is not None:
                    q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale)
                else:
                    q_input, input_scale = ops.scaled_fp8_quant(
                        x, use_per_token_if_dynamic=True)
                q_input, input_scale = self.quant_fp8(x, scale=scale)

        else:
            # Not quantized, nothing to do with the input_scales, use as is
@@ -269,7 +274,10 @@ def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
    def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
        if not self.quantized:
            return params_dtype
        return self._get_quant_dtype()

    def _get_quant_dtype(self) -> torch.dtype:
        assert self.quantized
        assert self.weight_quant is not None
        assert self.input_quant is not None

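
Taken together, the fp8 activation-quant call sites above now funnel through QuantFP8. A standalone sketch of that usage (mirrors the constructor call in __init__ and the call in apply_weights; assumes a CUDA device and, as in the benchmark file earlier in this PR, sets a vLLM config for custom-op dispatch):

import torch

from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

x = torch.rand(16, 4096, dtype=torch.float16, device="cuda")  # stand-in activations

config = VllmConfig(compilation_config=CompilationConfig(custom_ops=["none"]))
with set_current_vllm_config(config):
    quant_fp8 = QuantFP8(False, GroupShape.PER_TOKEN)  # dynamic, per-token scales
    q_input, input_scale = quant_fp8(x, scale=None)    # scale=None -> computed on the fly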