
Commit ca9f59e

Fix and add compile ranges test
Signed-off-by: ilmarkov <[email protected]>
1 parent d7600ec commit ca9f59e

File tree: 4 files changed, +94 -31 lines changed
Lines changed: 86 additions & 0 deletions (new file)
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+from torch import nn
+from torch.library import Library
+
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
+from vllm.forward_context import set_forward_context
+from vllm.utils import direct_register_custom_op
+
+# create a library to hold the custom op
+silly_lib = Library("silly", "FRAGMENT")  # noqa
+
+BATCH_SIZE = 64
+MLP_SIZE = 128
+
+
+def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                    out: torch.Tensor) -> None:
+    out.copy_(q)
+    out += k
+    out += v
+
+
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    target_lib=silly_lib,
+)
+
+
+@support_torch_compile
+class TestModel(nn.Module):
+
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + x
+        attn_output = torch.empty_like(x)
+        torch.ops.silly.attention(x, x, x, attn_output)
+        x = attn_output
+        x = x * 3
+        return x
+
+
+@torch.inference_mode
+def run_model(vllm_config: VllmConfig, model: nn.Module,
+              batch_sizes: list[int]):
+    with set_forward_context({}, vllm_config=vllm_config):
+        model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
+        for batch_size in batch_sizes:
+            model(torch.randn(batch_size, MLP_SIZE).cuda())
+
+
+def test_compile_ranges():
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        compile_ranges_split_points=[8, 32],
+    ))
+
+    with set_current_vllm_config(vllm_config):
+        model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda()
+        batch_sizes = [1, 16, 48]
+        # TestModel has support_torch_compile
+        with compilation_counter.expect(
+                num_graphs_seen=1,
+                num_piecewise_graphs_seen=1,
+                num_backend_compilations=4,
+                # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        ):
+            run_model(vllm_config, model, batch_sizes)
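
Aside (illustration, not part of this commit): with compile_ranges_split_points=[8, 32], the test drives one batch size into each bucket the split points induce, 1, 16, and 48, after the initial warm-up at BATCH_SIZE=64. Below is a minimal sketch of such a bucketing helper, assuming right-exclusive boundaries; compile_range_index is a hypothetical name, not a vLLM API, and vLLM's exact boundary semantics may differ.

# Illustration only: map a batch size to a compile-range bucket given split
# points such as [8, 32]. Assumes right-exclusive boundaries; not a vLLM API.
import bisect


def compile_range_index(batch_size: int, split_points: list[int]) -> int:
    # bisect_right([8, 32], 1) -> 0, 16 -> 1, 48 -> 2
    return bisect.bisect_right(split_points, batch_size)


if __name__ == "__main__":
    for size in (1, 16, 48, 64):  # sizes exercised by the test above
        print(size, "-> bucket", compile_range_index(size, [8, 32]))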

tests/compile/test_fusion_all_reduce.py

Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,7 @@
 from .backend import TestBackend


-def finisher(hidden_states):
+def maybe_dummy_quant(hidden_states):
     custom_ops = get_current_vllm_config().compilation_config.custom_ops
     if not custom_ops or "+quant_fp8" not in custom_ops:
         # Hack: use dynamic fp8 quantization to
@@ -53,7 +53,7 @@ def forward(self, hidden_states, residual):

         hidden_states = self.norm(all_reduce)

-        hidden_states = finisher(hidden_states)
+        hidden_states = maybe_dummy_quant(hidden_states)

         return hidden_states

@@ -80,7 +80,7 @@ def forward(self, hidden_states, residual):
         # Hack: use dynamic fp8 quantization to
         # suppress torch.compile optimizations
         # that prevent pattern matching
-        hidden_states = finisher(hidden_states)
+        hidden_states = maybe_dummy_quant(hidden_states)
         return hidden_states, residual

     def ops_in_model_after(self):
@@ -122,7 +122,7 @@ def forward(self, hidden_states, residual):
         all_reduce = tensor_model_parallel_all_reduce(view)
         norm_output, residual_output = self.norm(all_reduce, residual)
         output, _ = self.quant_fp8(norm_output, self.scale)
-        hidden_states = finisher(output.to(hidden_states.dtype))
+        hidden_states = maybe_dummy_quant(output.to(hidden_states.dtype))
         return hidden_states, residual_output

     def ops_in_model_after(self):

vllm/compilation/collective_fusion.py

Lines changed: 0 additions & 26 deletions
@@ -10,7 +10,6 @@
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from torch.distributed._symmetric_memory import enable_symm_mem_for_group

-import vllm.envs as envs
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
@@ -398,31 +397,6 @@ def __call__(self, graph: fx.Graph):
 if flashinfer_comm is not None:
     _FI_WORKSPACE_TENSOR = None

-    MiB = 1024 * 1024
-    # Max size of the input tensor per world size
-    # to use flashinfer fused allreduce
-    _FI_MAX_SIZES = {
-        2: 64 * MiB,  # 64MB
-        4: MiB,  # 1MB
-        6: MiB // 2,  # 512KB
-        8: MiB // 2,  # 512KB
-    }
-
-    try:
-        _FI_MAX_SIZES.update({
-            int(k): int(float(v) * MiB)
-            for k, v in
-            envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
-        })
-    except Exception as e:
-        raise ValueError(
-            "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
-            + str(e)) from e
-
-    # opt for a more conservative default value
-    # when world size is not in _FI_MAX_SIZES
-    _DEFAULT_FI_MAX_SIZE = MiB // 2
-
     def call_trtllm_fused_allreduce_norm(
             allreduce_in: torch.Tensor,
             residual: torch.Tensor,
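
Aside (illustration, not part of this commit): the block removed above implemented a per-world-size input-size cap for the flashinfer fused allreduce, with MiB overrides parsed from VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB and a conservative 512KB fallback. A compact sketch of that pattern follows; max_fused_allreduce_bytes is a hypothetical helper, not a vLLM API.

# Illustration only: per-world-size byte threshold with optional MiB overrides
# and a conservative fallback, mirroring the block removed above.
MiB = 1024 * 1024
_DEFAULT_MAX_SIZE = MiB // 2  # 512KB when the world size is unknown


def max_fused_allreduce_bytes(world_size: int,
                              overrides_mb: dict[str, str] | None = None) -> int:
    max_sizes = {2: 64 * MiB, 4: MiB, 6: MiB // 2, 8: MiB // 2}
    if overrides_mb:
        # e.g. {"4": "2"} raises the world-size-4 threshold to 2 MiB
        max_sizes.update(
            {int(k): int(float(v) * MiB) for k, v in overrides_mb.items()})
    return max_sizes.get(world_size, _DEFAULT_MAX_SIZE)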

vllm/utils/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -72,7 +72,6 @@

 import vllm.envs as envs
 from vllm.logger import enable_trace_function_call, init_logger
-from vllm.platforms import current_platform
 from vllm.ray.lazy_utils import is_in_ray_actor

 if TYPE_CHECKING:
@@ -128,6 +127,10 @@ def flashinfer_max_size(world_size: int, config: VllmConfig) -> Optional[int]:
     allreduce fusion for the given world size. Falls back to
     conservative defaults if the world size is not specified in config.
     """
+
+    # import here to avoid circular dependencies
+    from vllm.platforms import current_platform
+
     device_capability = current_platform.get_device_capability(
     ).as_version_str()
     max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {})
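
Aside (illustration, not part of this commit): the hunk above moves the current_platform import from module scope into the function body, presumably because importing vllm.platforms eagerly would create an import cycle with vllm.utils. A minimal sketch of that deferred-import pattern, with generic names:

# Illustration only: a lazy, function-local import breaks an import-time cycle
# at the cost of resolving the module on first call.
def get_capability_key() -> str:
    # Imported here so that importing this module never pulls in
    # vllm.platforms eagerly.
    from vllm.platforms import current_platform
    return current_platform.get_device_capability().as_version_str()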
