Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,270 changes: 1,270 additions & 0 deletions benchmarks/kernels/benchmark_fused_collective.py

Large diffs are not rendered by default.

86 changes: 86 additions & 0 deletions tests/compile/test_compile_ranges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa

BATCH_SIZE = 64
MLP_SIZE = 128


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return


direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)


@support_torch_compile
class TestModel(nn.Module):

def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x


@torch.inference_mode
def run_model(vllm_config: VllmConfig, model: nn.Module,
batch_sizes: list[int]):
with set_forward_context({}, vllm_config=vllm_config):
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
for batch_size in batch_sizes:
model(torch.randn(batch_size, MLP_SIZE).cuda())


def test_compile_ranges():
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
compile_ranges_split_points=[8, 32],
))

with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda()
batch_sizes = [1, 16, 48]
# A has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=1,
num_backend_compilations=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
run_model(vllm_config, model, batch_sizes)
143 changes: 97 additions & 46 deletions tests/compile/test_fusion_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
ModelConfig, PassConfig, VllmConfig,
get_current_vllm_config, set_current_vllm_config)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
Expand All @@ -24,7 +27,19 @@
from .backend import TestBackend


def maybe_dummy_quant(hidden_states):
custom_ops = get_current_vllm_config().compilation_config.custom_ops
if not custom_ops or "+quant_fp8" not in custom_ops:
# Hack: use dynamic fp8 quantization to
# suppress torch.compile optimizations
# that prevent pattern matching
return ops.scaled_fp8_quant(hidden_states)
else:
return hidden_states


class TestAllReduceRMSNormModel(torch.nn.Module):
pattern_code = 1

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
Expand All @@ -33,10 +48,14 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
self.norm = RMSNorm(hidden_size, eps)

def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm = self.norm(all_reduce)
return norm
# view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(hidden_states)

hidden_states = self.norm(all_reduce)

hidden_states = maybe_dummy_quant(hidden_states)

return hidden_states

def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
Expand All @@ -46,6 +65,7 @@ def ops_in_model_after(self):


class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
pattern_code = 1

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
Expand All @@ -56,49 +76,66 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm, _ = self.norm(all_reduce, residual)
return norm

def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
hidden_states, residual = self.norm(all_reduce, residual)
# Hack: use dynamic fp8 quantization to
# suppress torch.compile optimizations
# that prevent pattern matching
hidden_states = maybe_dummy_quant(hidden_states)
return hidden_states, residual

def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
]


class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
pattern_code = 2

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
dtype=current_platform.fp8_dtype())

def _quant_fp8_wrapper(x, scale):
torch.ops._C.static_scaled_fp8_quant(self.output, x, scale)
return self.output, scale

vllm_config = get_current_vllm_config()
if "+quant_fp8" in vllm_config.compilation_config.custom_ops:
# Need to use static_scaled_fp8_quant instead of QuantFP8
# due to failure in TestBackend with copying graph
self.quant_fp8 = _quant_fp8_wrapper
else:
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)

def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output,
norm_output.contiguous(),
self.scale)
return self.output, residual_output
output, _ = self.quant_fp8(norm_output, self.scale)
hidden_states = maybe_dummy_quant(output.to(hidden_states.dtype))
return hidden_states, residual_output

def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
]


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
pattern_code = 3

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
Expand Down Expand Up @@ -131,7 +168,6 @@ def ops_in_model_after(self):
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default
]


Expand All @@ -142,9 +178,12 @@ def ops_in_model_before(self):
TestAllReduceRMSNormModel,
TestAllReduceFusedAddRMSNormModel,
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
# TODO: Enable with torch==2.8.0
# TODO: Enable with flashinfer v0.3.0
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
@pytest.mark.parametrize(
"custom_ops",
[[], ["+rms_norm"], ["+quant_fp8"], ["+rms_norm", "+quant_fp8"]])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
Expand All @@ -157,19 +196,23 @@ def ops_in_model_before(self):
reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
custom_ops: list[str], batch_size: int,
seq_len: int, hidden_size: int,
dtype: torch.dtype):
num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)):
pytest.skip("Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)")
if (test_model != TestAllReduceFusedAddRMSNormStaticQuantFP8Model
and ("+quant_fp8" in custom_ops)):
pytest.skip()

def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model,
batch_size, seq_len, hidden_size,
dtype),
dtype, custom_ops),
nprocs=nprocs)

run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
Expand All @@ -178,7 +221,8 @@ def run_torch_spawn(fn, nprocs):
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
hidden_size: int, dtype: torch.dtype,
custom_ops: list[str]):
current_platform.seed_everything(0)

device = torch.device(f"cuda:{local_rank}")
Expand All @@ -198,8 +242,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
initialize_model_parallel(tensor_model_parallel_size=world_size)

vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"]))
level=CompilationLevel.PIECEWISE, custom_ops=custom_ops))
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
Expand All @@ -211,22 +254,30 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
trust_remote_code=True,
dtype=dtype,
seed=42)

all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)

backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)

hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)

compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)

backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
with set_current_vllm_config(vllm_config):
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)

backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)

hidden_states = torch.randn((token_num, hidden_size),
requires_grad=False)
residual = torch.randn((token_num, hidden_size),
dtype=torch.float32,
requires_grad=False)

compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)

backend.check_before_ops(model.ops_in_model_before(),
fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
for node in find_op_nodes(
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
backend.graph_post_pass):
assert (
node.kwargs.get("pattern_code") == test_model_cls.pattern_code)
Loading