
Commit d7600ec

Update limits
Signed-off-by: ilmarkov <[email protected]>
1 parent 3db307d commit d7600ec

File tree: 6 files changed, +156 -93 lines changed

benchmarks/kernels/benchmark_fused_collective.py

Lines changed: 55 additions & 49 deletions
@@ -64,12 +64,13 @@
 FP8_DTYPE = current_platform.fp8_dtype()
 MiB = 1024 * 1024
 
-# FlashInfer max sizes per world size (from collective_fusion.py)
+# FlashInfer max sizes per world size
+# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes
+# use --disable-oneshot to disable oneshot mode for very large input sizes
 _FI_MAX_SIZES = {
     2: 64 * MiB,  # 64MB
-    4: 32 * MiB,  # 32MB
-    6: 32 * MiB,  # 32MB
-    8: 32 * MiB,  # 32MB
+    4: 64 * MiB,  # 64MB
+    8: 64 * MiB,  # 64MB
 }
 
 # Global workspace tensor for FlashInfer
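Note: the benchmark now admits inputs up to 64 MiB for TP 2, 4, and 8. A minimal sketch of how a candidate (num_tokens, hidden_size) shape compares against these caps; the helper name fits_flashinfer and the element_size argument are illustrative, not part of the benchmark itself:

MiB = 1024 * 1024
_FI_MAX_SIZES = {2: 64 * MiB, 4: 64 * MiB, 8: 64 * MiB}

def fits_flashinfer(world_size, num_tokens, hidden_size, element_size=2):
    # Input bytes = tokens * hidden dim * bytes per element (2 for bf16).
    size = num_tokens * hidden_size * element_size
    limit = _FI_MAX_SIZES.get(world_size)
    return limit is not None and size <= limit

print(fits_flashinfer(8, 4096, 8192))  # True: exactly 64 MiB, right at the cap
print(fits_flashinfer(8, 8192, 8192))  # False: 128 MiB exceeds the cap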
@@ -186,7 +187,7 @@ def flashinfer_fused_allreduce_rmsnorm(
         allreduce_out=None,
         quant_out=None,
         scale_out=None,
-        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED,
+        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
         scale_factor=None,
         use_oneshot=use_oneshot,
         **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
@@ -228,7 +229,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
         allreduce_out=None,
         quant_out=quant_out,
         scale_out=None,
-        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED,
+        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
         scale_factor=scale_factor,
         use_oneshot=use_oneshot,
         **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
@@ -271,7 +272,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant(
         allreduce_out=None,
         quant_out=quant_out,
         scale_out=output_scale,
-        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED,
+        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
         scale_factor=input_global_scale,
         use_oneshot=use_oneshot,
         **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
@@ -579,6 +580,7 @@ def run_benchmarks(
     use_residual: bool,
     allreduce_params: Optional[FlashInferFusedAllReduceParams],
     quant_mode: str = "all",
+    disable_oneshot: bool = False,
 ):
     """Run all benchmarks for given configuration.
 
@@ -638,17 +640,18 @@ def run_benchmarks(
     # FlashInfer Fused AllReduce + RMSNorm Oneshot
     if flashinfer_comm is not None and allreduce_params is not None:
         try:
-            time_ms = benchmark_operation(
-                flashinfer_fused_allreduce_rmsnorm,
-                input_tensor,
-                residual=residual,
-                norm_out=norm_out,
-                rms_gamma=rms_gamma,
-                rms_eps=rms_eps,
-                allreduce_params=allreduce_params,
-                use_oneshot=True,
-            )
-            results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms
+            if not disable_oneshot:
+                time_ms = benchmark_operation(
+                    flashinfer_fused_allreduce_rmsnorm,
+                    input_tensor,
+                    residual=residual,
+                    norm_out=norm_out,
+                    rms_gamma=rms_gamma,
+                    rms_eps=rms_eps,
+                    allreduce_params=allreduce_params,
+                    use_oneshot=True,
+                )
+                results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms
         except Exception as e:
             logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e)
             results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf")
@@ -712,21 +715,22 @@ def run_benchmarks(
     # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot
     if flashinfer_comm is not None and allreduce_params is not None:
         try:
-            time_ms = benchmark_operation(
-                flashinfer_fused_allreduce_rmsnorm_fp8_quant,
-                input_tensor,
-                norm_out=norm_out,
-                residual=residual,
-                rms_gamma=rms_gamma,
-                rms_eps=rms_eps,
-                scale_factor=scale_fp8,
-                quant_out=quant_out_fp8,
-                allreduce_params=allreduce_params,
-                use_oneshot=True,
-            )
-            results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = (
-                time_ms
-            )
+            if not disable_oneshot:
+                time_ms = benchmark_operation(
+                    flashinfer_fused_allreduce_rmsnorm_fp8_quant,
+                    input_tensor,
+                    norm_out=norm_out,
+                    residual=residual,
+                    rms_gamma=rms_gamma,
+                    rms_eps=rms_eps,
+                    scale_factor=scale_fp8,
+                    quant_out=quant_out_fp8,
+                    allreduce_params=allreduce_params,
+                    use_oneshot=True,
+                )
+                results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = (
+                    time_ms
+                )
         except Exception as e:
             logger.error(
                 "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s",
@@ -802,22 +806,23 @@ def run_benchmarks(
     # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot
     if flashinfer_comm is not None and allreduce_params is not None:
         try:
-            time_ms = benchmark_operation(
-                flashinfer_fused_allreduce_rmsnorm_fp4_quant,
-                input_tensor,
-                residual=residual,
-                norm_out=norm_out,
-                rms_gamma=rms_gamma,
-                rms_eps=rms_eps,
-                input_global_scale=scale_fp4,
-                allreduce_params=allreduce_params,
-                quant_out=fp4_quant_out,
-                output_scale=fp4_output_scale,
-                use_oneshot=True,
-            )
-            results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = (
-                time_ms
-            )
+            if not disable_oneshot:
+                time_ms = benchmark_operation(
+                    flashinfer_fused_allreduce_rmsnorm_fp4_quant,
+                    input_tensor,
+                    residual=residual,
+                    norm_out=norm_out,
+                    rms_gamma=rms_gamma,
+                    rms_eps=rms_eps,
+                    input_global_scale=scale_fp4,
+                    allreduce_params=allreduce_params,
+                    quant_out=fp4_quant_out,
+                    output_scale=fp4_output_scale,
+                    use_oneshot=True,
+                )
+                results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = (
+                    time_ms
+                )
         except Exception as e:
             logger.error(
                 "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s",
@@ -1224,6 +1229,7 @@ def main():
                 use_residual,
                 allreduce_params,
                 quant_mode=quant_mode,
+                disable_oneshot=args.disable_oneshot,
             )
 
             # Store results for markdown export
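For reference, a minimal argparse stub showing how a --disable-oneshot flag would map onto the args.disable_oneshot attribute used above; the real benchmark defines many more options, so this is only an assumption about the flag wiring:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-oneshot", action="store_true",
                    help="Skip one-shot FlashInfer variants for very large inputs")
args = parser.parse_args(["--disable-oneshot"])
print(args.disable_oneshot)  # True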

vllm/compilation/collective_fusion.py

Lines changed: 33 additions & 11 deletions
@@ -20,7 +20,8 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import (_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES,
+                        direct_register_custom_op, flashinfer_max_size)
 from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
@@ -439,6 +440,23 @@ def call_trtllm_fused_allreduce_norm(
     scale_out: Optional[torch.Tensor] = None,
     scale_factor: Optional[torch.Tensor] = None,
 ) -> None:
+    num_tokens, hidden_size = allreduce_in.shape
+    element_size = allreduce_in.element_size()
+    current_tensor_size = num_tokens * hidden_size * element_size
+    max_tensor_size = max_token_num * hidden_size * element_size
+    assert current_tensor_size <= max_tensor_size, \
+        f"Current tensor size {current_tensor_size} is larger than " \
+        f"max token num {max_token_num} * hidden size {hidden_size} * " \
+        f"element size {element_size}"
+    device_capability = current_platform.get_device_capability(
+    ).as_version_str()
+    max_sizes = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get(device_capability, {})
+    # Get one shot input size limit for the current world size
+    max_one_shot_size = max_sizes.get(world_size, None)
+    # Use one shot if no max size is specified
+    use_oneshot = max_one_shot_size is None or \
+        current_tensor_size <= max_one_shot_size
+
     assert (
         _FI_WORKSPACE_TENSOR
         is not None), "Flashinfer must be enabled when using flashinfer"
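A condensed sketch of the one-shot dispatch added above. The per-capability table values here are hypothetical; the real _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES lives in vllm.utils and is not shown in this diff:

MiB = 1024 * 1024

_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = {
    "9.0": {2: 32 * MiB, 4: 4 * MiB, 8: 1 * MiB},  # hypothetical values
}

def pick_use_oneshot(device_capability, world_size, tensor_bytes):
    # Default to one-shot when no cap is registered, otherwise compare sizes,
    # mirroring the use_oneshot expression in the diff.
    max_sizes = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get(device_capability, {})
    max_one_shot = max_sizes.get(world_size)
    return max_one_shot is None or tensor_bytes <= max_one_shot

print(pick_use_oneshot("9.0", 8, 512 * 1024))  # True: under the 1 MiB cap
print(pick_use_oneshot("9.0", 8, 8 * MiB))     # False: falls back to two-shot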
@@ -465,7 +483,7 @@ def call_trtllm_fused_allreduce_norm(
             hidden_dim=allreduce_in.shape[-1],
             workspace_ptrs=_FI_WORKSPACE_TENSOR,
             launch_with_pdl=launch_with_pdl,
-            use_oneshot=True,
+            use_oneshot=use_oneshot,
             trigger_completion_at_end=trigger_completion_at_end,
             fp32_acc=fp32_acc,
             pattern_code=pattern_code,
@@ -1458,24 +1476,28 @@ def __init__(self, config: VllmConfig):
                 "Flashinfer is not installed or comm module not found, "
                 "skipping allreduce fusion pass")
             return
-        # Check if the world size is supported
-        if self.tp_size not in _FI_MAX_SIZES:
+        max_size = flashinfer_max_size(self.tp_size, config)
+        if max_size is None:
+            # Flashinfer doesn't support current world size
             logger.warning(
                 "Flashinfer allreduce fusion is not "
                 "supported for world size %s",
                 self.tp_size,
             )
             return
-        max_num_token = min(
-            _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
-            (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
-            config.compilation_config.pass_config.
-            fi_allreduce_fusion_max_token_num)
+        element_size = 4 if use_fp32_lamport else 2
+        max_token_num = (max_size //
+                         (self.hidden_dim * element_size))
+        # take the min to save workspace size and we'll never use more
+        # than max_num_batched_tokens anyways
+        max_token_num = min(max_token_num,
+                            config.scheduler_config.max_num_batched_tokens)
+
         self.ipc_handles, workspace_tensor = (
             flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                 tp_rank=rank,
                 tp_size=self.tp_size,
-                max_token_num=max_num_token,
+                max_token_num=max_token_num,
                 hidden_dim=self.hidden_dim,
                 group=self.group,
                 use_fp32_lamport=use_fp32_lamport,
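Rough numbers for the workspace sizing above, assuming a 32 MiB per-world-size cap and an 8192 hidden size (both figures are for illustration only):

MiB = 1024 * 1024
max_size = 32 * MiB                 # what flashinfer_max_size() might return
hidden_dim = 8192
use_fp32_lamport = False
element_size = 4 if use_fp32_lamport else 2

max_token_num = max_size // (hidden_dim * element_size)
print(max_token_num)                # 2048 tokens fit in the workspace

max_num_batched_tokens = 8192       # scheduler budget, illustrative
max_token_num = min(max_token_num, max_num_batched_tokens)
print(max_token_num)                # still 2048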
@@ -1487,7 +1509,7 @@ def __init__(self, config: VllmConfig):
             rank=rank,
             world_size=self.tp_size,
             use_fp32_lamport=use_fp32_lamport,
-            max_token_num=max_num_token,
+            max_token_num=max_token_num,
         )
         is_custom_ops = ("+rms_norm" in config.compilation_config.custom_ops,
                          "+quant_fp8" in config.compilation_config.custom_ops)

vllm/config/__init__.py

Lines changed: 11 additions & 11 deletions
@@ -49,10 +49,8 @@
     try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
-from vllm.utils import (_DEFAULT_FI_ALLREDUCE_MAX_INPUT_SIZE,
-                        _FI_ALLREDUCE_MAX_INPUT_SIZES,
-                        DEFAULT_MAX_NUM_BATCHED_TOKENS, LayerBlockType,
-                        LazyLoader, common_broadcastable_dtype, random_uuid)
+from vllm.utils import (LayerBlockType, LazyLoader, common_broadcastable_dtype,
+                        flashinfer_max_size, random_uuid)
 
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
@@ -3879,13 +3877,15 @@ def _set_compile_ranges(self):
         # Add the compile ranges for flashinfer
         if compilation_config.pass_config.enable_fi_allreduce_fusion:
             tp_size = self.parallel_config.tensor_parallel_size
-            max_size = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(
-                tp_size, _DEFAULT_FI_ALLREDUCE_MAX_INPUT_SIZE)
-            max_token_num = max_size // (self.model_config.get_hidden_size() *
-                                         self.model_config.dtype.itemsize)
-            # We add 1 because the bounds checks in the compiler are exclusive
-            # and we want to include the max_token_num in the compile range
-            computed_compile_ranges_split_points.append(max_token_num + 1)
+            max_size = flashinfer_max_size(tp_size, self)
+            if max_size is not None:
+                max_token_num = max_size // (
+                    self.model_config.get_hidden_size() *
+                    self.model_config.dtype.itemsize)
+                # We add 1 because the bounds checks in the compiler are
+                # exclusive and we want to include the max_token_num in the
+                # compile range
+                computed_compile_ranges_split_points.append(max_token_num + 1)
 
         if compilation_config.compile_ranges_split_points is not None:
             for x in compilation_config.compile_ranges_split_points:
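Worked example of the compile-range split point, using an assumed 32 MiB threshold and a 4096-wide bf16 model:

max_size = 32 * 1024 * 1024   # assumed flashinfer_max_size() result
hidden_size = 4096
itemsize = 2                  # bf16 element width in bytes

max_token_num = max_size // (hidden_size * itemsize)
split_point = max_token_num + 1    # +1: compiler bounds checks are exclusive
print(max_token_num, split_point)  # 4096 4097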

vllm/config/compilation.py

Lines changed: 8 additions & 2 deletions
@@ -87,8 +87,14 @@ class PassConfig:
     """Whether to enable async TP."""
     enable_fi_allreduce_fusion: bool = False
     """Whether to enable flashinfer allreduce fusion."""
-    fi_allreduce_fusion_max_token_num: int = 16384
-    """Max number of tokens to used in flashinfer allreduce fusion."""
+    fi_allreduce_fusion_max_size_mb: dict[int,
+                                          float] = field(default_factory=dict)
+    """The thresholds of the communicated tensor sizes under which
+    vllm should use flashinfer fused allreduce. Specified as a
+    dictionary mapping each world size to the threshold in MB
+    { <world size>: <max size in mb> }
+    Unspecified world sizes will fallback to
+    { 2: 32, 4: 32, 8: 2 }"""
 
     # TODO(luka) better pass enabling system.
 
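A hedged sketch of setting the new knob programmatically; the exact construction path may differ, so treat the imports and constructor call as illustrative:

from vllm.config import CompilationConfig, PassConfig

# World sizes left out of the dict fall back to {2: 32, 4: 32, 8: 2} (MB).
compilation_config = CompilationConfig(
    pass_config=PassConfig(
        enable_fi_allreduce_fusion=True,
        fi_allreduce_fusion_max_size_mb={2: 64.0, 4: 32.0, 8: 2.0},
    ),
)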

vllm/envs.py

Lines changed: 0 additions & 11 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import hashlib
-import json
 import os
 import sys
 import tempfile
@@ -1059,16 +1058,6 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
     lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
 
-    # Specifies the thresholds of the communicated tensor sizes under which
-    # vllm should use flashinfer fused allreduce. The variable should be a
-    # JSON with the following format:
-    # { <world size>: <max size in mb> }
-    # Unspecified world sizes will fallback to
-    # { 2: 64, 4: 1, <everything else>: 0.5 }
-    "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
-    lambda: json.loads(os.getenv(
-        "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),
-
     # MoE routing strategy selector.
     # See `RoutingSimulator.get_available_strategies()` # for available
     # strategies.