From edd1e999288b2fa68a2ea1cb09474e4fe5dee401 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 4 Sep 2025 04:29:31 -0700 Subject: [PATCH 1/4] Split original pr Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 1270 ++++++++++++++++ tests/compile/test_fusion_all_reduce.py | 143 +- vllm/compilation/collective_fusion.py | 1318 +++++++++++++---- vllm/config/compilation.py | 63 +- 4 files changed, 2475 insertions(+), 319 deletions(-) create mode 100644 benchmarks/kernels/benchmark_fused_collective.py diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py new file mode 100644 index 000000000000..ea78875c62cf --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -0,0 +1,1270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Benchmark for FlashInfer fused collective operations vs standard operations. + +This benchmark compares: +1. FlashInfer's trtllm_allreduce_fusion (fused allreduce + rmsnorm + optional quant) +2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations + +Usage with torchrun: + torchrun --nproc_per_node=2 benchmark_fused_collective.py + +""" + +import argparse +import itertools +import os +import time +from typing import Optional + +import torch # type: ignore +import torch.distributed as dist # type: ignore + +from vllm.distributed import ( + get_tp_group, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import ( + graph_capture, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm # noqa +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 # noqa +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape # noqa +from vllm.platforms import current_platform # noqa + +RMS_NORM_OP = torch.ops._C.rms_norm +FUSED_ADD_RMS_NORM_OP = torch.ops._C.fused_add_rms_norm +RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant +FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( + torch.ops._C.fused_add_rms_norm_static_fp8_quant +) +SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant + +logger = init_logger(__name__) + +# Try to import FlashInfer +try: + import flashinfer.comm as flashinfer_comm # type: ignore + + if not hasattr(flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = None + logger.warning( + "FlashInfer comm module found but missing trtllm_allreduce_fusion" + ) +except ImportError: + flashinfer_comm = None + logger.warning("FlashInfer not found, only benchmarking standard operations") + +# Constants +FP8_DTYPE = current_platform.fp8_dtype() +MiB = 1024 * 1024 + +# FlashInfer max sizes per world size +# Enable 64MB for 2, 4, 8 world sizes to verify large input sizes +# use --disable-oneshot to disable oneshot mode for very large input sizes +_FI_MAX_SIZES = { + 2: 64 * MiB, # 64MB + 4: 64 * MiB, # 64MB + 8: 64 * MiB, # 64MB +} + +# Global workspace tensor for FlashInfer +_FI_WORKSPACE_TENSOR = None + + +def setup_flashinfer_workspace( + world_size: int, + rank: int, + hidden_dim: int, + max_token_num: int, + use_fp32_lamport: bool = False, +): + """Setup FlashInfer workspace for fused allreduce operations.""" + global _FI_WORKSPACE_TENSOR + + if flashinfer_comm is None: + return None, None + + if world_size not in _FI_MAX_SIZES: + logger.warning("FlashInfer not supported for world size %s", 
world_size)
+        return None, None
+
+    try:
+        # Create IPC workspace
+        ipc_handles, workspace_tensor = (
+            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
+                tp_rank=rank,
+                tp_size=world_size,
+                max_token_num=max_token_num,
+                hidden_dim=hidden_dim,
+                group=get_tp_group().device_group,
+                use_fp32_lamport=use_fp32_lamport,
+            )
+        )
+
+        _FI_WORKSPACE_TENSOR = workspace_tensor
+        return ipc_handles, workspace_tensor
+    except Exception as e:
+        logger.error("Failed to setup FlashInfer workspace: %s", e)
+        return None, None
+
+
+def cleanup_flashinfer_workspace(ipc_handles):
+    """Cleanup FlashInfer workspace."""
+    if flashinfer_comm is None or ipc_handles is None:
+        return
+
+    try:
+        group = get_tp_group().device_group
+        flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group)
+    except Exception as e:
+        logger.error("Failed to cleanup FlashInfer workspace: %s", e)
+
+
+class FlashInferFusedAllReduceParams:
+    """Parameters for FlashInfer fused allreduce operations."""
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        use_fp32_lamport: bool = False,
+        max_token_num: int = 1024,
+    ):
+        self.rank = rank
+        self.world_size = world_size
+        self.use_fp32_lamport = use_fp32_lamport
+        self.trigger_completion_at_end = True
+        self.launch_with_pdl = True
+        self.fp32_acc = True
+        self.max_token_num = max_token_num
+
+    def get_trtllm_fused_allreduce_kwargs(self):
+        return {
+            "world_rank": self.rank,
+            "world_size": self.world_size,
+            "launch_with_pdl": self.launch_with_pdl,
+            "trigger_completion_at_end": self.trigger_completion_at_end,
+            "fp32_acc": self.fp32_acc,
+        }
+
+
+def flashinfer_fused_allreduce_rmsnorm(
+    input_tensor: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    rms_gamma: torch.Tensor,
+    rms_eps: float,
+    allreduce_params: "FlashInferFusedAllReduceParams",
+    use_oneshot: bool,
+    norm_out: Optional[torch.Tensor] = None,
+):
+    """FlashInfer fused allreduce + rmsnorm operation."""
+    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+        raise RuntimeError("FlashInfer not available or workspace not initialized")
+
+    if norm_out is None:
+        norm_out = input_tensor
+        residual_out = residual
+    else:
+        residual_out = input_tensor
+
+    flashinfer_comm.trtllm_allreduce_fusion(
+        allreduce_in=input_tensor,
+        token_num=input_tensor.shape[0],
+        residual_in=residual,
+        residual_out=residual_out,
+        norm_out=norm_out,
+        rms_gamma=rms_gamma,
+        rms_eps=rms_eps,
+        hidden_dim=input_tensor.shape[-1],
+        workspace_ptrs=_FI_WORKSPACE_TENSOR,
+        pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
+        allreduce_out=None,
+        quant_out=None,
+        scale_out=None,
+        layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4,
+        scale_factor=None,
+        use_oneshot=use_oneshot,
+        **allreduce_params.get_trtllm_fused_allreduce_kwargs(),
+    )
+
+
+def flashinfer_fused_allreduce_rmsnorm_fp8_quant(
+    input_tensor: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    rms_gamma: torch.Tensor,
+    rms_eps: float,
+    scale_factor: torch.Tensor,
+    allreduce_params: FlashInferFusedAllReduceParams,
+    use_oneshot: bool = True,
+    norm_out: Optional[torch.Tensor] = None,
+    quant_out: Optional[torch.Tensor] = None,
+):
+    """FlashInfer fused allreduce + rmsnorm + FP8 quantization."""
+    if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None:
+        raise RuntimeError("FlashInfer not available or workspace not initialized")
+
+    if norm_out is None:
+        norm_out = input_tensor
+        residual_out = residual
+    else:
+        residual_out = input_tensor
+
+    flashinfer_comm.trtllm_allreduce_fusion(
allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=None, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def flashinfer_fused_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + input_global_scale: torch.Tensor, + allreduce_params: FlashInferFusedAllReduceParams, + quant_out: torch.Tensor, + use_oneshot: bool, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" + if flashinfer_comm is None or _FI_WORKSPACE_TENSOR is None: + raise RuntimeError("FlashInfer not available or workspace not initialized") + + if norm_out is None: + norm_out = input_tensor + residual_out = residual + else: + residual_out = input_tensor + + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + token_num=input_tensor.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + hidden_dim=input_tensor.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + allreduce_out=None, + quant_out=quant_out, + scale_out=output_scale, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=input_global_scale, + use_oneshot=use_oneshot, + **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + +def standard_allreduce_rmsnorm( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Then RMS norm + if residual is not None: + # Fused add + RMS norm + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + else: + # Just RMS norm + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + + +def standard_allreduce_rmsnorm_fp8_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: float, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization.""" + if quant_out is None: + quant_out = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then fused RMS norm + FP8 quantization + if residual is not None: + FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, residual, rms_gamma, scale_factor, rms_eps + ) + return quant_out, residual + else: + RMS_NORM_STATIC_FP8_QUANT_OP( + quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps + ) + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rms_gamma: torch.Tensor, + rms_eps: 
float, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP4 quantization.""" + + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Then RMS norm + if residual is not None: + FUSED_ADD_RMS_NORM_OP(allreduce_out, residual, rms_gamma, rms_eps) + quant_input = allreduce_out + residual_out = residual + else: + if norm_out is None: + norm_out = torch.empty_like(allreduce_out) + RMS_NORM_OP(norm_out, allreduce_out, rms_gamma, rms_eps) + quant_input = norm_out + residual_out = allreduce_out + + # Finally FP4 quantization + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +def standard_allreduce_rmsnorm_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm operations using native RMSNorm forward.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + # Apply native RMSNorm + if residual is not None: + result = rmsnorm_layer.forward_native(allreduce_out, residual) + return result # Returns (norm_out, residual_out) + else: + result = rmsnorm_layer.forward_native(allreduce_out) + return result # Returns norm_out + + +def standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP8 quantization using native implementations.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + residual_out = allreduce_out + + # Apply native FP8 quantization + quant_out, _ = quant_fp8_layer.forward_native(norm_out, scale=scale_factor) + + if residual is not None: + return quant_out, residual_out + else: + return quant_out + + +def standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Standard allreduce + rmsnorm + FP4 quantization using native RMSNorm.""" + # All-reduce first + allreduce_out = tensor_model_parallel_all_reduce(input_tensor) + + # Apply native RMSNorm + if residual is not None: + norm_out, residual_out = rmsnorm_layer.forward_native(allreduce_out, residual) + quant_input = norm_out + else: + norm_out = rmsnorm_layer.forward_native(allreduce_out) + quant_input = norm_out + residual_out = allreduce_out + + # Apply FP4 quantization (still using fused CUDA op as there's no native FP4) + SCALED_FP4_QUANT_OP(quant_out, quant_input, output_scale, input_global_scale) + + if residual is not None: + return quant_out, residual_out, output_scale + else: + return quant_out, norm_out + + +# Compiled versions of native functions +@torch.compile +def standard_allreduce_rmsnorm_native_compiled( + input_tensor: torch.Tensor, + residual: 
Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm.""" + return standard_allreduce_rmsnorm_native( + input_tensor, residual, rmsnorm_layer, norm_out + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp8_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + quant_fp8_layer: QuantFP8, + scale_factor: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP8 quantization.""" + return standard_allreduce_rmsnorm_fp8_quant_native( + input_tensor, + residual, + rmsnorm_layer, + quant_fp8_layer, + scale_factor, + norm_out, + quant_out, + ) + + +@torch.compile +def standard_allreduce_rmsnorm_fp4_quant_native_compiled( + input_tensor: torch.Tensor, + residual: Optional[torch.Tensor], + rmsnorm_layer: RMSNorm, + input_global_scale: torch.Tensor, + quant_out: torch.Tensor, + output_scale: torch.Tensor, + norm_out: Optional[torch.Tensor] = None, +): + """Compiled version of standard allreduce + rmsnorm + FP4 quantization.""" + return standard_allreduce_rmsnorm_fp4_quant_native( + input_tensor, + residual, + rmsnorm_layer, + input_global_scale, + quant_out, + output_scale, + norm_out, + ) + + +def create_test_tensors( + seq_len: int, hidden_dim: int, dtype: torch.dtype, use_residual: bool = True +): + """Create test tensors for benchmarking.""" + input_tensor = torch.randn(seq_len, hidden_dim, dtype=dtype) + residual = ( + torch.randn_like(input_tensor) + if use_residual + else torch.zeros_like(input_tensor) + ) + rms_gamma = torch.ones(hidden_dim, dtype=dtype) + norm_out = None if use_residual else torch.empty_like(input_tensor) + + # Quantization scales + scale_fp8 = torch.tensor(1.0, dtype=torch.float32) + scale_fp4 = torch.tensor(1.0, dtype=torch.float32) + quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) + # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) + fp4_quant_out = torch.empty((seq_len, hidden_dim // 2), dtype=torch.uint8) + fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + + return ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) + + +def benchmark_operation( + operation_func, *args, warmup: int = 5, trials: int = 20, **kwargs +): + """Benchmark a single operation using CUDA graphs.""" + # Warmup before graph capture + for _ in range(warmup): + operation_func(*args, **kwargs) + torch.cuda.synchronize() + + # Create CUDA graph + graph = torch.cuda.CUDAGraph() + num_op_per_cudagraph = 10 + + # Use vLLM's graph_capture to make tensor_model_parallel_all_reduce graph-safe + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with graph_capture(device=device), torch.cuda.graph(graph): + for _ in range(num_op_per_cudagraph): + operation_func(*args, **kwargs) + + # Graph warmup + torch.cuda.synchronize() + for _ in range(warmup): + graph.replay() + + # Benchmark with CUDA graph + torch.cuda.synchronize() + start_time = time.perf_counter() + + for _ in range(trials // num_op_per_cudagraph): + # operation_func(*args, **kwargs) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter() + + avg_time_ms = ((end_time - start_time) / trials) * 1000 + return avg_time_ms + + +def run_benchmarks( + seq_len: int, + hidden_dim: int, + dtype: torch.dtype, + 
use_residual: bool, + allreduce_params: Optional[FlashInferFusedAllReduceParams], + quant_mode: str = "all", + disable_oneshot: bool = False, +): + """Run all benchmarks for given configuration. + + Args: + quant_mode: "none", "fp8_only", "fp4_only", or "all" + """ + ( + input_tensor, + norm_out, + residual, + rms_gamma, + scale_fp8, + quant_out_fp8, + scale_fp4, + fp4_quant_out, + fp4_output_scale, + ) = create_test_tensors(seq_len, hidden_dim, dtype, use_residual) + + rms_eps = 1e-6 + results = {} + + # Create RMSNorm and QuantFP8 layers once for native benchmarks + rmsnorm_layer = RMSNorm(hidden_dim, eps=rms_eps, dtype=dtype) + rmsnorm_layer.weight.data = rms_gamma + quant_fp8_layer = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) + + if quant_mode in ["all", "none"]: + # Standard AllReduce + RMSNorm + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + ) + results["standard_allreduce_rmsnorm"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm failed: %s", e) + results["standard_allreduce_rmsnorm"] = float("inf") + + # Standard AllReduce + RMSNorm Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = time_ms + except Exception as e: + logger.error("FlashInfer Fused AllReduce+RMSNorm Oneshot failed: %s", e) + results["flashinfer_fused_allreduce_rmsnorm_oneshot"] = float("inf") + + # FlashInfer Fused AllReduce + RMSNorm Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = time_ms + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm Two-shot failed: %s", e + ) + results["flashinfer_fused_allreduce_rmsnorm_twoshot"] = float("inf") + + if quant_mode in ["all", "fp8_only"]: + # Standard AllReduce + RMSNorm + FP8 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp8_quant_native_compiled, + input_tensor, + residual=residual, + 
rmsnorm_layer=rmsnorm_layer, + quant_fp8_layer=quant_fp8_layer, + scale_factor=scale_fp8, + norm_out=norm_out, + quant_out=quant_out_fp8, + ) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP8 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp8_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_oneshot"] = float( + "inf" + ) + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Two-shot + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp8_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_fp8, + quant_out=quant_out_fp8, + allreduce_params=allreduce_params, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP8 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp8_quant_twoshot"] = float( + "inf" + ) + + if quant_mode in ["all", "fp4_only"]: + # Standard AllReduce + RMSNorm + FP4 Quant + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant, + input_tensor, + norm_out=norm_out, + residual=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + ) + results["standard_allreduce_rmsnorm_fp4_quant"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant"] = float("inf") + + # Standard AllReduce + RMSNorm + FP4 Quant Native Compiled + try: + time_ms = benchmark_operation( + standard_allreduce_rmsnorm_fp4_quant_native_compiled, + input_tensor, + residual=residual, + rmsnorm_layer=rmsnorm_layer, + input_global_scale=scale_fp4, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + norm_out=norm_out, + ) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = time_ms + except Exception as e: + logger.error("Standard AllReduce+RMSNorm+FP4 Native Compiled failed: %s", e) + results["standard_allreduce_rmsnorm_fp4_quant_native_compiled"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot + if flashinfer_comm is not None and allreduce_params is not None: + try: + if not disable_oneshot: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=True, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = ( 
+ time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_oneshot"] = float( + "inf" + ) + + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot + if flashinfer_comm is not None and allreduce_params is not None: + try: + time_ms = benchmark_operation( + flashinfer_fused_allreduce_rmsnorm_fp4_quant, + input_tensor, + residual=residual, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + input_global_scale=scale_fp4, + allreduce_params=allreduce_params, + quant_out=fp4_quant_out, + output_scale=fp4_output_scale, + use_oneshot=False, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( + time_ms + ) + except Exception as e: + logger.error( + "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", + e, + ) + results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( + "inf" + ) + + return results + + +def prepare_results_with_speedups(results_dict): + """Prepare results with speedup calculations based on dynamic baseline selection.""" + prepared_results = [] + + # Determine the fastest baseline for each operation type + def get_fastest_baseline(op_name, results_dict): + """Get the fastest baseline between standard and native_compiled versions.""" + if "fp8_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp8_quant", + "standard_allreduce_rmsnorm_fp8_quant_native_compiled", + ] + elif "fp4_quant" in op_name: + candidates = [ + "standard_allreduce_rmsnorm_fp4_quant", + "standard_allreduce_rmsnorm_fp4_quant_native_compiled", + ] + else: + candidates = [ + "standard_allreduce_rmsnorm", + "standard_allreduce_rmsnorm_native_compiled", + ] + + # Find the fastest among available candidates + fastest_time = float("inf") + fastest_baseline = None + + for candidate in candidates: + if ( + candidate in results_dict + and results_dict[candidate] != float("inf") + and results_dict[candidate] < fastest_time + ): + fastest_time = results_dict[candidate] + fastest_baseline = candidate + + return fastest_baseline + + # Create dynamic baseline mapping + dynamic_baseline_mapping = {} + for op_name in results_dict: + if ( + op_name.startswith("flashinfer_") + or op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + dynamic_baseline_mapping[op_name] = get_fastest_baseline( + op_name, results_dict + ) + + for op_name, time_ms in results_dict.items(): + if time_ms == float("inf"): + speedup_str = "FAILED" + time_str = "FAILED" + else: + time_str = f"{time_ms:.3f}" + # Find the appropriate baseline for this operation + baseline_op = dynamic_baseline_mapping.get(op_name) + if baseline_op and baseline_op in results_dict: + baseline_time = results_dict[baseline_op] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + # For baseline operations, determine if this is the fastest baseline + if op_name.endswith("_native_compiled") or ( + op_name.startswith("standard_") + and not op_name.endswith("_native_compiled") + ): + fastest_baseline = get_fastest_baseline(op_name, results_dict) + if fastest_baseline == op_name: + speedup_str = "baseline" + else: + if fastest_baseline and fastest_baseline in results_dict: + baseline_time = results_dict[fastest_baseline] + if baseline_time != float("inf") and baseline_time > 0: + speedup = baseline_time / time_ms + speedup_str = 
f"{speedup:.2f}x" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + else: + speedup_str = "N/A" + + prepared_results.append( + { + "operation": op_name, + "time_ms": time_ms, + "time_str": time_str, + "speedup_str": speedup_str, + } + ) + + return prepared_results + + +def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): + """Print benchmark results in a formatted table.""" + print(f"\n{'=' * 80}") + print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") + print( + f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " + f"quant_mode={quant_mode}" + ) + print(f"{'=' * 80}") + print(f"{'Operation':<50} {'Time (ms)':<12} {'Speedup':<10}") + print(f"{'-' * 80}") + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + if result["time_ms"] == float("inf"): + time_display = result["time_str"] + else: + time_display = f"{result['time_ms']:.3f}" + + print( + f"{result['operation']:<50} {time_display:<12} {result['speedup_str']:<10}" + ) + + +def format_results_markdown( + all_results: list[dict], world_size: int, args: argparse.Namespace +) -> str: + """Format all benchmark results as markdown.""" + markdown = f"""# FlashInfer Fused Collective Operations Benchmark Results + +**World Size:** {world_size} +**Hidden Dimension:** {args.hidden_dim} +**Warmup Iterations:** {args.warmup} +**Benchmark Trials:** {args.trials} +**Quantization Mode:** {all_results[0]["quant_mode"] if all_results else "N/A"} + +--- + +""" + + for result in all_results: + seq_len = result["seq_len"] + dtype = result["dtype"] + use_residual = result["use_residual"] + results_dict = result["results"] + + residual_str = "with residual" if use_residual else "no residual" + + markdown += f""" +## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} + +| Operation | Time (ms) | Speedup | +|-----------|-----------|---------| +""" + + # Prepare results with speedup calculations + prepared_results = prepare_results_with_speedups(results_dict) + + for result in prepared_results: + # Format operation name for better readability + formatted_op_name = result["operation"].replace("_", " ").title() + markdown += f"| {formatted_op_name} | {result['time_str']} |" + markdown += f"{result['speedup_str']} |\n" + + markdown += "\n" + + return markdown + + +def save_results_to_file( + all_results: list[dict], world_size: int, args: argparse.Namespace, rank: int +): + """Save benchmark results to markdown file (only on rank 0).""" + if rank != 0: + return + + if not all_results: + logger.warning("No results to save") + return + + output_path = args.output_file + + try: + markdown_content = format_results_markdown(all_results, world_size, args) + + with open(output_path, "w") as f: + f.write(markdown_content) + + except Exception as e: + logger.error("Failed to save results to file: %s", e) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark fused collective operations" + ) + parser.add_argument( + "--seq-lens", + type=int, + nargs="+", + default=[128, 512, 1024, 2048], + help="Sequence lengths to test", + ) + parser.add_argument( + "--hidden-dim", type=int, default=8192, help="Hidden dimension size" + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["bfloat16"], + choices=["float16", "bfloat16", "float32"], + help="Data types to test", + ) + parser.add_argument( + "--no-residual", + action="store_true", + help="Skip residual 
connection tests", + ) + + # Quantization mode options (mutually exclusive with --no-quant) + quant_group = parser.add_mutually_exclusive_group() + quant_group.add_argument( + "--no-quant", action="store_true", help="Skip all quantization tests" + ) + quant_group.add_argument( + "--quant-fp8", action="store_true", help="Only run FP8 quantization tests" + ) + quant_group.add_argument( + "--quant-fp4", action="store_true", help="Only run FP4 quantization tests" + ) + quant_group.add_argument( + "--quant-all", + action="store_true", + help="Run all quantization tests (default)", + ) + + parser.add_argument( + "--disable-oneshot", + action="store_true", + help="Disable oneshot mode for FlashInfer operations", + ) + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--trials", type=int, default=20, help="Number of benchmark trials" + ) + parser.add_argument( + "--output-file", + type=str, + help="""Output file path for markdown results + (default: benchmark_results_.md) + """, + ) + + args = parser.parse_args() + + # Check if running with torchrun (required for collective operations) + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + raise RuntimeError( + "Must run with torchrun for distributed benchmarking. " + "Example: torchrun --nproc_per_node=2 benchmark_fused_collective.py" + ) + + # Initialize distributed environment + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # Validate world size (must be > 1 for collective operations) + if world_size <= 1: + raise ValueError( + "World size must be > 1 for collective operations benchmarking. " + f"Current world size: {world_size}. Use torchrun with --nproc_per_node > 1." 
+ ) + + # Determine quantization mode + if args.no_quant: + quant_mode = "none" + elif args.quant_fp8: + quant_mode = "fp8_only" + elif args.quant_fp4: + quant_mode = "fp4_only" + else: # args.quant_all or default + quant_mode = "all" + + if rank == 0: + logger.info("Running benchmark with world_size=%s, rank=%s", world_size, rank) + logger.info("Quantization mode: %s", quant_mode) + if flashinfer_comm is not None: + oneshot_status = "enabled" if not args.disable_oneshot else "disabled" + logger.info( + "FlashInfer available - will benchmark fused operations (oneshot: %s)", + oneshot_status, + ) + else: + logger.info( + "FlashInfer not available - only benchmarking standard operations" + ) + + # Convert dtype strings to torch dtypes + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + dtypes = [dtype_map[dt] for dt in args.dtypes] + + # Test configurations + residual_options = [True] if not args.no_residual else [False] + if not args.no_residual: + residual_options.append(False) + + configs = list(itertools.product(args.seq_lens, dtypes, residual_options)) + + # Setup FlashInfer workspace if available + ipc_handles = None + allreduce_params = None + + if flashinfer_comm is not None: + # Use the largest hidden dimension for workspace setup + max_num_token = _FI_MAX_SIZES.get(world_size) // ( + args.hidden_dim * world_size * 2 + ) + + ipc_handles, workspace_tensor = setup_flashinfer_workspace( + world_size, rank, args.hidden_dim, max_num_token + ) + + if workspace_tensor is not None: + allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=world_size, + max_token_num=max_num_token, + ) + + # Collect all results for markdown export + all_results = [] + + try: + # Run benchmarks + for seq_len, dtype, use_residual in configs: + if rank == 0: + logger.info( + "\nTesting: seq_len=%s, hidden_dim=%s, dtype=%s, residual=%s", + seq_len, + args.hidden_dim, + dtype, + use_residual, + ) + + results = run_benchmarks( + seq_len, + args.hidden_dim, + dtype, + use_residual, + allreduce_params, + quant_mode=quant_mode, + disable_oneshot=args.disable_oneshot, + ) + + # Store results for markdown export + if rank == 0: + all_results.append( + { + "seq_len": seq_len, + "hidden_dim": args.hidden_dim, + "dtype": str(dtype).replace("torch.", ""), + "use_residual": use_residual, + "quant_mode": quant_mode, + "results": results, + } + ) + + print_results( + results, + seq_len, + args.hidden_dim, + dtype, + use_residual, + quant_mode, + ) + + # Save results to markdown file + if args.output_file and rank == 0: + save_results_to_file(all_results, world_size, args, rank) + + finally: + # Cleanup + if ipc_handles is not None: + cleanup_flashinfer_workspace(ipc_handles) + + dist.barrier() + + +if __name__ == "__main__": + main() diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index dd31e0db1f59..049375d039e5 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -6,11 +6,14 @@ import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, - ModelConfig, PassConfig, VllmConfig) + ModelConfig, PassConfig, 
VllmConfig, + get_current_vllm_config, set_current_vllm_config) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) @@ -24,7 +27,19 @@ from .backend import TestBackend +def maybe_dummy_quant(hidden_states): + custom_ops = get_current_vllm_config().compilation_config.custom_ops + if not custom_ops or "+quant_fp8" not in custom_ops: + # Hack: use dynamic fp8 quantization to + # suppress torch.compile optimizations + # that prevent pattern matching + return ops.scaled_fp8_quant(hidden_states) + else: + return hidden_states + + class TestAllReduceRMSNormModel(torch.nn.Module): + pattern_code = 1 def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() @@ -33,10 +48,14 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): self.norm = RMSNorm(hidden_size, eps) def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm = self.norm(all_reduce) - return norm + # view = hidden_states.reshape(-1, self.hidden_size) + all_reduce = tensor_model_parallel_all_reduce(hidden_states) + + hidden_states = self.norm(all_reduce) + + hidden_states = maybe_dummy_quant(hidden_states) + + return hidden_states def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -46,6 +65,7 @@ def ops_in_model_after(self): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): + pattern_code = 1 def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() @@ -56,37 +76,54 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) - norm, _ = self.norm(all_reduce, residual) - return norm - - def ops_in_model_before(self): - return [torch.ops.vllm.all_reduce.default] + hidden_states, residual = self.norm(all_reduce, residual) + # Hack: use dynamic fp8 quantization to + # suppress torch.compile optimizations + # that prevent pattern matching + hidden_states = maybe_dummy_quant(hidden_states) + return hidden_states, residual def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + def ops_in_model_before(self): + return [ + torch.ops.vllm.all_reduce.default, + ] + class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): + pattern_code = 2 def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = RMSNorm(hidden_size, eps) - self.quant_fp8 = QuantFP8(static=True, - group_shape=GroupShape.PER_TENSOR) - self.scale = torch.rand(1, dtype=torch.float32) self.output = torch.empty((token_num, hidden_size), - dtype=torch.float32) + dtype=current_platform.fp8_dtype()) + + def _quant_fp8_wrapper(x, scale): + torch.ops._C.static_scaled_fp8_quant(self.output, x, scale) + return self.output, scale + + vllm_config = get_current_vllm_config() + if "+quant_fp8" in vllm_config.compilation_config.custom_ops: + # Need to use static_scaled_fp8_quant instead of QuantFP8 + # due to failure in TestBackend with copying graph + self.quant_fp8 = _quant_fp8_wrapper + else: + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.scale = torch.rand(1, dtype=torch.float32) def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) 
all_reduce = tensor_model_parallel_all_reduce(view) norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.static_scaled_fp8_quant(self.output, - norm_output.contiguous(), - self.scale) - return self.output, residual_output + output, _ = self.quant_fp8(norm_output, self.scale) + hidden_states = maybe_dummy_quant(output.to(hidden_states.dtype)) + return hidden_states, residual_output def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -94,11 +131,11 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.static_scaled_fp8_quant.default ] class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): + pattern_code = 3 def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() @@ -131,7 +168,6 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.scaled_fp4_quant.default ] @@ -142,9 +178,12 @@ def ops_in_model_before(self): TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel, TestAllReduceFusedAddRMSNormStaticQuantFP8Model, - # TODO: Enable with torch==2.8.0 + # TODO: Enable with flashinfer v0.3.0 # TestAllReduceFusedAddRMSNormStaticQuantFP4Model, ]) +@pytest.mark.parametrize( + "custom_ops", + [[], ["+rms_norm"], ["+quant_fp8"], ["+rms_norm", "+quant_fp8"]]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [16]) @@ -157,19 +196,23 @@ def ops_in_model_before(self): reason="flashinfer is not found or flashinfer " "is not compiled with trtllm_allreduce_fusion") def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): + custom_ops: list[str], batch_size: int, + seq_len: int, hidden_size: int, + dtype: torch.dtype): num_processes = 2 if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model and not current_platform.has_device_capability(100)): pytest.skip("Skip as nvfp4 is only supported on " "devices with compute capability 10.0 (Blackwell)") + if (test_model != TestAllReduceFusedAddRMSNormStaticQuantFP8Model + and ("+quant_fp8" in custom_ops)): + pytest.skip() def run_torch_spawn(fn, nprocs): torch.multiprocessing.spawn(fn, args=(num_processes, test_model, batch_size, seq_len, hidden_size, - dtype), + dtype, custom_ops), nprocs=nprocs) run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes) @@ -178,7 +221,8 @@ def run_torch_spawn(fn, nprocs): def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, test_model_cls: torch.nn.Module, batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): + hidden_size: int, dtype: torch.dtype, + custom_ops: list[str]): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -198,8 +242,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, initialize_model_parallel(tensor_model_parallel_size=world_size) vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"])) + level=CompilationLevel.PIECEWISE, custom_ops=custom_ops)) vllm_config.compilation_config.pass_config = PassConfig( enable_fi_allreduce_fusion=True, enable_noop=True) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) @@ -211,22 +254,30 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, 
world_size: int, trust_remote_code=True, dtype=dtype, seed=42) - - all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) - noop_pass = NoOpEliminationPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) - - token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num) - - hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - residual = torch.randn((token_num, hidden_size), requires_grad=False) - - compiled_model = torch.compile(model, backend=backend) - compiled_model(hidden_states, residual) - - backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) - backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass + with set_current_vllm_config(vllm_config): + all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + + backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass) + + token_num = batch_size * seq_len + model = test_model_cls(hidden_size, token_num) + + hidden_states = torch.randn((token_num, hidden_size), + requires_grad=False) + residual = torch.randn((token_num, hidden_size), + dtype=torch.float32, + requires_grad=False) + + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states, residual) + + backend.check_before_ops(model.ops_in_model_before(), + fully_replaced=False) + backend.check_after_ops(model.ops_in_model_after()) + for node in find_op_nodes( + torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, + backend.graph_post_pass): + assert ( + node.kwargs.get("pattern_code") == test_model_cls.pattern_code) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 71274420c342..7024cd55bf83 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,12 +10,14 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -396,30 +398,21 @@ def __call__(self, graph: fx.Graph): _FI_WORKSPACE_TENSOR = None MiB = 1024 * 1024 - # Max size of the input tensor per world size - # to use flashinfer fused allreduce - _FI_MAX_SIZES = { - 2: 64 * MiB, # 64MB - 4: MiB, # 1MB - 6: MiB // 2, # 512KB - 8: MiB // 2, # 512KB + # Max size of the input tensor per world size per device capability + # to use flashinfer one shot fused allreduce + _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = { + "9.0": { + 2: 32 * MiB, # 32MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 32 * MiB, # 32MB + 4: 4 * MiB, # 4MB + 8: 1 * MiB, # 1MB + }, } - try: - _FI_MAX_SIZES.update({ - int(k): int(float(v) * MiB) - for k, v in - envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - }) - except Exception as e: - raise 
ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " - + str(e)) from e - - # opt for a more conservative default value - # when world size is not in _FI_MAX_SIZES - _DEFAULT_FI_MAX_SIZE = MiB // 2 - def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, residual: torch.Tensor, @@ -432,7 +425,6 @@ def call_trtllm_fused_allreduce_norm( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, scale_out: Optional[torch.Tensor] = None, @@ -441,12 +433,20 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_fusion_size = max_token_num * hidden_size * element_size - use_flashinfer = current_tensor_size <= min( - _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), - max_fusion_size, - ) - if use_flashinfer: + max_tensor_size = max_token_num * hidden_size * element_size + + if current_tensor_size <= max_tensor_size: + device_capability = current_platform.get_device_capability( + ).as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES. \ + get(device_capability, {}). \ + get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = max_one_shot_size is None or \ + current_tensor_size <= max_one_shot_size + assert (_FI_WORKSPACE_TENSOR is not None ), "Flashinfer must be enabled when using flashinfer" if norm_out is None: @@ -472,7 +472,7 @@ def call_trtllm_fused_allreduce_norm( hidden_dim=allreduce_in.shape[-1], workspace_ptrs=_FI_WORKSPACE_TENSOR, launch_with_pdl=launch_with_pdl, - use_oneshot=True, + use_oneshot=use_oneshot, trigger_completion_at_end=trigger_completion_at_end, fp32_acc=fp32_acc, pattern_code=pattern_code, @@ -486,8 +486,7 @@ def call_trtllm_fused_allreduce_norm( ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if (scale_factor is not None and scale_out is None - and fuse_rms_quant): + if (scale_factor is not None and scale_out is None): # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( @@ -530,7 +529,6 @@ def call_trtllm_fused_allreduce_norm_fake( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, scale_out: Optional[torch.Tensor] = None, @@ -563,7 +561,6 @@ def __init__( world_size: int, use_fp32_lamport: bool = False, max_token_num: int = 1024, - fuse_rms_quant: bool = False, ): self.rank = rank self.world_size = world_size @@ -571,9 +568,7 @@ def __init__( self.trigger_completion_at_end = True self.launch_with_pdl = True self.fp32_acc = True - self.use_oneshot = False self.max_token_num = max_token_num - self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -583,35 +578,113 @@ def get_trtllm_fused_allreduce_kwargs(self): "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, - "fuse_rms_quant": self.fuse_rms_quant, } -class AllReduceRMSNormPattern(BasePattern): +def rms_norm_native(input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + residual: Optional[torch.Tensor] = None): + orig_dtype = input.dtype + input = input.to(torch.float32) + if 
residual is not None: + input = input + residual.to(torch.float32) + # residual = input.to(orig_dtype) + residual = input + + variance = input.pow(2).mean(dim=-1, keepdim=True) + + input = input * torch.rsqrt(variance + epsilon) + input = input.to(orig_dtype) + input = input * weight + if residual is None: + return input + else: + return input, residual + + +class AllReduceRMSNormNativePattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) + This pattern replaces allreduce + torch native rms norm (without residual) with fused flashinfer implementation. Applies to allreduce + rmsnorm before attn in the first Transformer block. """ def __init__( - self, - epsilon: float, - dtype: torch.dtype, - device: str, - allreduce_params: FlashInferFusedAllReduceParams, + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + is_custom_ops: tuple[bool, bool] = (False, False), ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.is_custom_rms_norm = is_custom_ops[0] def get_inputs(self): - input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - rms_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + return [input, weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(input: torch.Tensor, weight: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + rms_output = rms_norm_native(allreduce_output, weight, + self.epsilon) + # rms_result, allreduce_output + return rms_output, allreduce_output + + def replacement(input: torch.Tensor, weight: torch.Tensor): + residual = torch.zeros_like(input) + rms_result = torch.empty_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=rms_result, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNorm, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # rms_result, allreduce_in + return allreduce[3], allreduce[1] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceRMSNormCustomOpPattern(BasePattern): + """ + This pattern replaces the allreduce + custom op rms norm (without residual) + with fused flashinfer implementation. + Applies to allreduce + rmsnorm before attn in the first Transformer block. 
+ """ + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + is_custom_ops: tuple[bool, bool] = (False, False), + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.is_custom_rms_norm = is_custom_ops[0] + + def get_inputs(self): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_result = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [input, rms_result, weight] def register(self, pm_pass: PatternMatcherPass): @@ -626,7 +699,7 @@ def pattern(input: torch.Tensor, rms_result: torch.Tensor, weight=weight, epsilon=self.epsilon, ) - # rms_result, allreduce_output + return rms[1], allreduce_output def replacement(input: torch.Tensor, rms_result: torch.Tensor, @@ -652,9 +725,69 @@ def replacement(input: torch.Tensor, rms_result: torch.Tensor, pm.fwd_only, pm_pass) -class AllReduceFusedAddRMSNormPattern(BasePattern): +class AllReduceFusedAddRMSNormNativePattern(BasePattern): + """ + This pattern replaces the allreduce + torch native rms norm (with residual) + with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. + """ + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + residual = torch.empty([4, 4], device=self.device, dtype=torch.float32) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + return [ + residual, + input, + weight, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + rms_output, rms_residual = rms_norm_native(allreduce_output, + weight, self.epsilon, + residual) + return rms_output, rms_residual + + def replacement(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual.to(self.dtype), + norm_out=None, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNorm, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # allreduce_in, residual + return allreduce[1], allreduce[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormCustomOpPattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (with residual) + This pattern replaces the allreduce + custom op rms norm (with residual) with fused flashinfer implementation. Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. """ @@ -717,63 +850,53 @@ def replacement(residual: torch.Tensor, input: torch.Tensor, pm.fwd_only, pm_pass) -class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): +class AllReduceFusedRMSNormNativeStaticQuantFP8NativePattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) - + static fp8 quant with fused flashinfer implementation. 
+ This pattern replaces allreduce + torch native rms norm (without residual) + + native static fp8 quant with fused flashinfer implementation. Applies to allreduce + rmsnorm + quant before attn in the first Transformer block. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + is_custom_ops: tuple[bool, bool] = (False, False)): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params - self.quant_dtype = torch.float8_e4m3fn + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() + self.is_custom_rms_norm = is_custom_ops[0] + self.is_custom_fp8 = is_custom_ops[1] - def register(self, pm_pass: PatternMatcherPass): + def get_inputs(self): + input = torch.zeros([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [input, weight, scale] - def get_inputs(): - input = torch.zeros([1, 8, 4], - device=self.device, - dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.quant_dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) - scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized(RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon) - - quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, - result=quant_result, - input=rmsnorm_out_tuple[1], - scale=scale) - + rmsnorm_result = rms_norm_native(all_reduce, weight, self.epsilon) + quant_out, _ = self.quant_fp8(rmsnorm_result, scale=scale) # quant_out, allreduce_output - return quant_out_tuple[1], all_reduce + return quant_out, all_reduce - def replacement(input: torch.Tensor, result_rms: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) + quant_result = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -784,7 +907,7 @@ def replacement(input: torch.Tensor, result_rms: torch.Tensor, rms_gamma=weight, rms_eps=self.epsilon, pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + kARResidualRMSNormFP8Quant, scale_factor=scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) @@ -792,102 +915,90 @@ def replacement(input: torch.Tensor, result_rms: torch.Tensor, # quant_out, allreduce_output return allreduce[4], allreduce[1] - pm.register_replacement(pattern, replacement, get_inputs(), + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): +class AllReduceFusedRMSNormCustomOpStaticQuantFP8NativePattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (with residual) - + static fp8 quant with fused flashinfer implementation. - Applies to o_proj + rmsnorm after attn + quant and - mlp + rmsnorm + quant before attn. + This pattern replaces the allreduce + custom op rms norm (without residual) + + native static fp8 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + is_custom_ops: tuple[bool, bool] = (False, False)): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params - self.quant_dtype = torch.float8_e4m3fn - - def register(self, pm_pass: PatternMatcherPass): + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() - def get_inputs(): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + def get_inputs(self): + input = torch.zeros([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + rmsnorm_result = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + return [input, weight, scale, rmsnorm_result] - residual = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty([4, 4], - device=self.device, - dtype=self.quant_dtype) - scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) - - return [ - quant_result, - residual, - input, - weight, - scale, - ] + def register(self, pm_pass: PatternMatcherPass): + # rmsnorm custom op def pattern( - quant_result: torch.Tensor, - residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, + rmsnorm_result: torch.Tensor, ): - allreduce_output = tensor_model_parallel_all_reduce(input) - - fused_add_rmsnorm_out_tuple = \ - auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon) - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale) + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out_tuple = auto_functionalized(RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon) + quant_out, _ = self.quant_fp8(rmsnorm_out_tuple[1], scale=scale) # quant_out, allreduce_output - return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] + return quant_out, all_reduce - def replacement(quant_result: torch.Tensor, 
residual: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def replacement(input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, rmsnorm_result: torch.Tensor): + residual = torch.zeros_like(input) + quant_result = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, - norm_out=None, + norm_out=rmsnorm_result, quant_out=quant_result, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + kARResidualRMSNormFP8Quant, scale_factor=scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) - # # quant_out, rms_norm_residual - return allreduce[4], allreduce[2] - pm.register_replacement(pattern, replacement, get_inputs(), + # quant_out, allreduce_output + return allreduce[4], allreduce[1] + + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): +class AllReduceFusedRMSNormNativeStaticQuantFP8CustomOpPattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) - + static nvfp4 quant with fused flashinfer implementation. + This pattern replaces allreduce + torch native rms norm (without residual) + + custom op static fp8 quant with fused flashinfer implementation. Applies to allreduce + rmsnorm + quant before attn in the first Transformer block. """ @@ -897,91 +1008,71 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str, super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() - def register(self, pm_pass: PatternMatcherPass): - - def get_inputs(): - input = torch.empty([1, 16, 16], - device=self.device, - dtype=self.dtype) + def get_inputs(self): + input = torch.zeros([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + quant_result = torch.empty([4, 4], + device=self.device, + dtype=self.quant_dtype) + return [input, weight, scale, quant_result] - rmsnorm_result = torch.empty([1, 16, 16], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty((16, 8), - device=self.device, - dtype=torch.uint8) - input_global_scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) - weight = torch.empty([16], device=self.device, dtype=self.dtype) - output_scale = torch.empty([128, 4], - device=self.device, - dtype=torch.int32) - - return [ - input, rmsnorm_result, quant_result, weight, - input_global_scale, output_scale - ] + def register(self, pm_pass: PatternMatcherPass): def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, - input_global_scale: torch.Tensor, - output_scale: torch.Tensor, + scale: torch.Tensor, + quant_result: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized(RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon) - - quant_out_tuple = auto_functionalized( - STATIC_FP4_QUANT_OP, - output=quant_result, - input=rmsnorm_out_tuple[1], - output_scale=output_scale, - 
input_scale=input_global_scale) - - # quant_out, allreduce_output, output_scale - return quant_out_tuple[1], all_reduce, quant_out_tuple[2] + rmsnorm_out = rms_norm_native(all_reduce, weight, self.epsilon) + quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, + result=quant_result, + input=rmsnorm_out, + scale=scale) + # quant_out, allreduce_output + return quant_out_tuple[1], all_reduce - def replacement(input: torch.Tensor, result_rms: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, - input_global_scale: torch.Tensor, - output_scale: torch.Tensor): + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + quant_result: torch.Tensor, + ): residual = torch.zeros_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, - norm_out=result_rms, + norm_out=None, quant_out=quant_result, - scale_out=output_scale, + scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards - scale_factor=input_global_scale, + kARResidualRMSNormFP8Quant, + scale_factor=scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) - # quant_out, allreduce_output, output_scale - return allreduce[4], allreduce[1], allreduce[5] + # quant_out, allreduce_output + return allreduce[4], allreduce[1] - pm.register_replacement(pattern, replacement, get_inputs(), + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): +class AllReduceFusedRMSNormCustomOpStaticQuantFP8CustomOpPattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (with residual) - + static nvfp4 quant with fused flashinfer implementation. - Applies to o_proj + rmsnorm after attn + quant and - mlp + rmsnorm + quant before attn. + This pattern replaces the allreduce + custom op rms norm (without residual) + + custom op static fp8 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. 
""" def __init__(self, epsilon: float, dtype: torch.dtype, device: str, @@ -989,57 +1080,669 @@ def __init__(self, epsilon: float, dtype: torch.dtype, device: str, super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() - def register(self, pm_pass: PatternMatcherPass): + def get_inputs(self): + input = torch.zeros([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - def get_inputs(): - input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + rmsnorm_result = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) - residual = torch.empty([16, 16], + quant_result = torch.empty([4, 4], device=self.device, - dtype=self.dtype) - weight = torch.empty([16, 16], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty((16, 8), - device=self.device, - dtype=torch.uint8) - input_global_scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) - output_scale = torch.empty([128, 4], - device=self.device, - dtype=torch.int32) - - return [ - quant_result, - residual, - input, - output_scale, - weight, - input_global_scale, - ] - - def pattern(quant_result: torch.Tensor, residual: torch.Tensor, - input: torch.Tensor, output_scale: torch.Tensor, - weight: torch.Tensor, input_global_scale: torch.Tensor): - allreduce_output = tensor_model_parallel_all_reduce(input) + dtype=self.quant_dtype) + return [input, weight, scale, rmsnorm_result, quant_result] - fused_add_rmsnorm_out_tuple = \ - auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon) - quant_out_tuple = auto_functionalized( - STATIC_FP4_QUANT_OP, - output=quant_result, + def register(self, pm_pass: PatternMatcherPass): + # rmsnorm custom op + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + ): + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out_tuple = auto_functionalized(RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon) + + quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, + result=quant_result, + input=rmsnorm_out_tuple[1], + scale=scale) + + # quant_out, allreduce_output + return quant_out_tuple[1], all_reduce + + def replacement(input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor): + residual = torch.zeros_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=rmsnorm_result, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. 
+ kARResidualRMSNormFP8Quant, + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + # quant_out, allreduce_output + return allreduce[4], allreduce[1] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormNativeStaticQuantFP8NativePattern(BasePattern): + """ + This pattern replaces the allreduce + torch native rms norm (with residual) + + torch native static fp8 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. + """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() + + def get_inputs(self): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=torch.float32) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + return [residual, input, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + allreduce_output = tensor_model_parallel_all_reduce(input) + rmsnorm_out, rmsnorm_residual = rms_norm_native( + allreduce_output, weight, self.epsilon, residual) + quant_out, _ = self.quant_fp8(rmsnorm_out, scale=scale) + # quant_out, allreduce_output + return quant_out, rmsnorm_residual + + def replacement(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): + quant_result = torch.empty_like(input, dtype=self.quant_dtype) + + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual.to(self.dtype), + norm_out=None, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP8Quant, + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # quant_out, rms_norm_residual + return allreduce[4], allreduce[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormCustomOpStaticQuantFP8NativePattern(BasePattern): + """ + This pattern replaces the allreduce + custom op rms norm (with residual) + + torch native static fp8 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. 
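The replacement functions in these patterns all index into the tuple returned by auto_functionalized; the mapping below is inferred from the return-value comments used throughout this file (index 0 is the functionalized op itself, and the mutated tensors follow in the order they are passed):

    # allreduce[1] -> allreduce_in  (the all-reduced hidden states)
    # allreduce[2] -> residual      (the updated residual)
    # allreduce[3] -> norm_out      (the RMSNorm output, when materialized)
    # allreduce[4] -> quant_out     (the FP8 / NVFP4 quantized output)
    # allreduce[5] -> scale_out     (the NVFP4 output block scales)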
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() + + def get_inputs(self): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + return [residual, input, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + allreduce_output = tensor_model_parallel_all_reduce(input) + + fused_add_rmsnorm_out_tuple = \ + auto_functionalized( + RMS_ADD_OP, + input=allreduce_output, + residual=residual, + weight=weight, + epsilon=self.epsilon) + quant_out, _ = self.quant_fp8(fused_add_rmsnorm_out_tuple[1], + scale=scale) + + # quant_out, allreduce_output + return quant_out, fused_add_rmsnorm_out_tuple[2] + + def replacement(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): + quant_result = torch.empty_like(input, dtype=self.quant_dtype) + + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP8Quant, + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # quant_out, rms_norm_residual + return allreduce[4], allreduce[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormNativeStaticQuantFP8CustomOpPattern(BasePattern): + """ + This pattern replaces the allreduce + torch native rms norm (with residual) + + custom op static fp8 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() + + def get_inputs(self): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=torch.float32) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + quant_result = torch.empty([4, 4], + device=self.device, + dtype=self.quant_dtype) + return [residual, input, weight, scale, quant_result] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + quant_result: torch.Tensor, + ): + allreduce_output = tensor_model_parallel_all_reduce(input) + rmsnorm_out, rmsnorm_residual = rms_norm_native( + allreduce_output, weight, self.epsilon, residual) + quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, + result=quant_result, + input=rmsnorm_out, + scale=scale) + # quant_out, allreduce_output + return quant_out_tuple[1], rmsnorm_residual + + def replacement(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor, + quant_result: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual.to(self.dtype), + norm_out=None, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP8Quant, + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # quant_out, rms_norm_residual + return allreduce[4], allreduce[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormCustomOpStaticQuantFP8CustomOpPattern( + BasePattern): + """ + This pattern replaces the allreduce + custom op rms norm (with residual) + + custom op static fp8 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + self.quant_fp8 = QuantFP8(static=True, + group_shape=GroupShape.PER_TENSOR) + self.quant_dtype = current_platform.fp8_dtype() + + def get_inputs(self): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + quant_result = torch.empty([4, 4], + device=self.device, + dtype=self.quant_dtype) + return [residual, input, weight, scale, quant_result] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + quant_result: torch.Tensor, + ): + allreduce_output = tensor_model_parallel_all_reduce(input) + + fused_add_rmsnorm_out_tuple = \ + auto_functionalized( + RMS_ADD_OP, + input=allreduce_output, + residual=residual, + weight=weight, + epsilon=self.epsilon) + quant_out_tuple = auto_functionalized( + STATIC_FP8_QUANT_OP, + result=quant_result, input=fused_add_rmsnorm_out_tuple[1], + scale=scale) + + # quant_out, allreduce_output + return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] + + def replacement(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor, + quant_result: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=None, + quant_out=quant_result, + scale_out=None, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP8Quant, + scale_factor=scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # quant_out, rms_norm_residual + return allreduce[4], allreduce[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedRMSNormNativeStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces allreduce + torch native rms norm (without residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self): + input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + weight = torch.empty([16], device=self.device, dtype=self.dtype) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + return [input, quant_result, weight, input_global_scale, output_scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + input: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + output_scale: torch.Tensor, + ): + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out = rms_norm_native(all_reduce, weight, self.epsilon) + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=rmsnorm_out, + output_scale=output_scale, + input_scale=input_global_scale) + return quant_out_tuple[1], all_reduce, quant_out_tuple[2] + + def replacement(input: torch.Tensor, quant_result: torch.Tensor, + weight: torch.Tensor, input_global_scale: torch.Tensor, + output_scale: torch.Tensor): + residual = torch.zeros_like(input) + result_rms = torch.zeros_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=result_rms, + quant_out=quant_result, + scale_out=output_scale, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP4Quant, + scale_factor=input_global_scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + # quant_out, allreduce_output, output_scale + return allreduce[4], allreduce[1], allreduce[5] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedRMSNormCustomOpStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces the allreduce + custom op rms norm (without residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to allreduce + rmsnorm + quant before attn + in the first Transformer block. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self): + input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + weight = torch.empty([16], device=self.device, dtype=self.dtype) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + rmsnorm_result = torch.empty([16, 16], + device=self.device, + dtype=self.dtype) + return [ + input, quant_result, weight, input_global_scale, output_scale, + rmsnorm_result + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + input: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + output_scale: torch.Tensor, + rmsnorm_result: torch.Tensor, + ): + all_reduce = tensor_model_parallel_all_reduce(input) + rmsnorm_out_tuple = auto_functionalized(RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon) + + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=rmsnorm_out_tuple[1], output_scale=output_scale, input_scale=input_global_scale) # quant_out, allreduce_output, output_scale + return quant_out_tuple[1], all_reduce, quant_out_tuple[2] + + def replacement(input: torch.Tensor, quant_result: torch.Tensor, + weight: torch.Tensor, input_global_scale: torch.Tensor, + output_scale: torch.Tensor, result_rms: torch.Tensor): + residual = torch.zeros_like(input) + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual, + norm_out=result_rms, + quant_out=quant_result, + scale_out=output_scale, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP4Quant, + scale_factor=input_global_scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + # quant_out, allreduce_output, output_scale + return allreduce[4], allreduce[1], allreduce[5] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormNativeStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces the allreduce + torch native rms norm (with residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self): + input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + + residual = torch.empty([16, 16], + device=self.device, + dtype=torch.float32) + weight = torch.empty([16, 16], device=self.device, dtype=self.dtype) + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + + return [ + quant_result, + residual, + input, + output_scale, + weight, + input_global_scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, output_scale: torch.Tensor, + weight: torch.Tensor, input_global_scale: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + rmsnorm_out, rms_residual = rms_norm_native( + allreduce_output, weight, self.epsilon, residual) + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=rmsnorm_out, + output_scale=output_scale, + input_scale=input_global_scale) + + # quant_out, residual, output_scale + return quant_out_tuple[1], rms_residual, quant_out_tuple[2] + + def replacement(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, output_scale: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor): + allreduce = auto_functionalized( + flashinfer_trtllm_fused_allreduce_norm, + allreduce_in=input, + residual=residual.to(self.dtype), + norm_out=None, + quant_out=quant_result, + scale_out=output_scale, + rms_gamma=weight, + rms_eps=self.epsilon, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + scale_factor=input_global_scale, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + # quant_out, rms_norm_residual, output_scale + return allreduce[4], allreduce[2], allreduce[5] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormCustomOpStaticQuantNVFP4Pattern(BasePattern): + """ + This pattern replaces the allreduce + custom op rms norm (with residual) + + static nvfp4 quant with fused flashinfer implementation. + Applies to o_proj + rmsnorm after attn + quant and + mlp + rmsnorm + quant before attn. 
+ """ + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + allreduce_params: FlashInferFusedAllReduceParams): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self): + input = torch.empty([16, 16], device=self.device, dtype=self.dtype) + + residual = torch.empty([16, 16], device=self.device, dtype=self.dtype) + weight = torch.empty([16, 16], device=self.device, dtype=self.dtype) + quant_result = torch.empty((16, 8), + device=self.device, + dtype=torch.uint8) + input_global_scale = torch.empty([1, 1], + device=self.device, + dtype=torch.float32) + output_scale = torch.empty([128, 4], + device=self.device, + dtype=torch.int32) + + return [ + quant_result, + residual, + input, + output_scale, + weight, + input_global_scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(quant_result: torch.Tensor, residual: torch.Tensor, + input: torch.Tensor, output_scale: torch.Tensor, + weight: torch.Tensor, input_global_scale: torch.Tensor): + allreduce_output = tensor_model_parallel_all_reduce(input) + + fused_add_rmsnorm_out_tuple = \ + auto_functionalized( + RMS_ADD_OP, + input=allreduce_output, + residual=residual, + weight=weight, + epsilon=self.epsilon) + quant_out_tuple = auto_functionalized( + STATIC_FP4_QUANT_OP, + output=quant_result, + input=fused_add_rmsnorm_out_tuple[1], + output_scale=output_scale, + input_scale=input_global_scale) + + # quant_out, residual, output_scale return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[ 2], quant_out_tuple[2] @@ -1064,7 +1767,7 @@ def replacement(quant_result: torch.Tensor, residual: torch.Tensor, # quant_out, rms_norm_residual, output_scale return allreduce[4], allreduce[2], allreduce[5] - pm.register_replacement(pattern, replacement, get_inputs(), + pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) @@ -1089,24 +1792,28 @@ def __init__(self, config: VllmConfig): "Flashinfer is not installed or comm module not found, " "skipping allreduce fusion pass") return - # Check if the world size is supported - if self.tp_size not in _FI_MAX_SIZES: + max_size = config.compilation_config.\ + pass_config.flashinfer_max_size(self.tp_size) + if max_size is None: + # Flashinfer doesn't support current world size logger.warning( "Flashinfer allreduce fusion is not " "supported for world size %s", self.tp_size, ) return - max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) // - (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config. 
- fi_allreduce_fusion_max_token_num) + element_size = 4 if use_fp32_lamport else 2 + max_token_num = (max_size // (self.hidden_dim * element_size)) + # take the min to save workspace size and we'll never use more + # than max_num_batched_tokens anyways + max_token_num = min(max_token_num, + config.scheduler_config.max_num_batched_tokens) + self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_num_token, + max_token_num=max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1118,48 +1825,117 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_num_token, - # fuse rms norm static fp8 quant fused op - # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) - - self.register_patterns() + max_token_num=max_token_num, + ) + with set_current_vllm_config(config), torch.device(self.device): + self.register_patterns() @enable_fake_mode def register_patterns(self): for epsilon in [1e-5, 1e-6]: - AllReduceFusedRMSNormStaticQuantFP8Pattern( + # rms norm + static fp8 quant + AllReduceFusedRMSNormNativeStaticQuantFP8NativePattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedRMSNormNativeStaticQuantFP8CustomOpPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedRMSNormCustomOpStaticQuantFP8NativePattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedRMSNormCustomOpStaticQuantFP8CustomOpPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + + # add rms norm + static fp8 quant + AllReduceFusedAddRMSNormNativeStaticQuantFP8NativePattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormNativeStaticQuantFP8CustomOpPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormCustomOpStaticQuantFP8NativePattern( epsilon, self.model_dtype, self.device, self.allreduce_params, ).register(self.patterns) - AllReduceFusedAddRMSNormStaticQuantFP8Pattern( + AllReduceFusedAddRMSNormCustomOpStaticQuantFP8CustomOpPattern( epsilon, self.model_dtype, self.device, self.allreduce_params, ).register(self.patterns) + if current_platform.has_device_capability(100): - AllReduceFusedRMSNormStaticQuantNVFP4Pattern( + # rms norm + static nvfp4 quant + AllReduceFusedRMSNormNativeStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedRMSNormCustomOpStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + + # add rms norm + static nvfp4 quant + AllReduceFusedAddRMSNormNativeStaticQuantNVFP4Pattern( epsilon, self.model_dtype, self.device, self.allreduce_params, ).register(self.patterns) - AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern( + AllReduceFusedAddRMSNormCustomOpStaticQuantNVFP4Pattern( epsilon, self.model_dtype, self.device, self.allreduce_params, ).register(self.patterns) - AllReduceRMSNormPattern( + + # rms norm + AllReduceRMSNormNativePattern( + epsilon, + self.model_dtype, + 
self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceRMSNormCustomOpPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + + # add rms norm + AllReduceFusedAddRMSNormNativePattern( epsilon, self.model_dtype, self.device, self.allreduce_params, ).register(self.patterns) - AllReduceFusedAddRMSNormPattern( + AllReduceFusedAddRMSNormCustomOpPattern( epsilon, self.model_dtype, self.device, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 677fb069bc07..039e92cf0ef1 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -87,11 +87,66 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: dict[int, + float] = field(default_factory=dict) + """The thresholds of the communicated tensor sizes under which + vllm should use flashinfer fused allreduce. Specified as a + dictionary mapping each world size to the threshold in MB + { <world size>: <max size in MB> } + Unspecified world sizes will fall back to + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + }, where key is the device capability""" # TODO(luka) better pass enabling system. + def flashinfer_max_size(self, world_size: int) -> Optional[int]: + """ + Returns the max communication size in bytes for flashinfer + allreduce fusion for the given world size. Falls back to + conservative defaults if the world size is not specified in config. + """ + + # import here to avoid circular dependencies + from vllm.platforms import current_platform + MiB = 1024 * 1024 + + # Max size of the input tensor per world size per device capability + # to use flashinfer fused allreduce + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + } + + device_capability = current_platform.get_device_capability( + ).as_version_str() + max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}) + max_sizes.update({ + k: int(v * MiB) + for k, v in self.fi_allreduce_fusion_max_size_mb.items() + }) + if world_size not in max_sizes: + # FlashInfer doesn't support other world sizes + return None + return max_sizes[world_size] + def uuid(self): """ Produces a hash unique to the pass configuration. @@ -110,6 +165,10 @@ def __post_init__(self) -> None: logger.warning_once( "Fusion enabled but reshape elimination disabled. " "Attention + quant (fp8) fusion might not work") + if self.enable_fi_allreduce_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. 
" + "Allreduce + rms norm + quant (fp8) fusion might not work") @config From 0fe1de448ad152103650c006a464e8af89cacf15 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 5 Sep 2025 05:58:33 -0700 Subject: [PATCH 2/4] Update bench Signed-off-by: ilmarkov --- .../kernels/benchmark_fused_collective.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index ea78875c62cf..7f012af36a94 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -187,7 +187,7 @@ def flashinfer_fused_allreduce_rmsnorm( allreduce_out=None, quant_out=None, scale_out=None, - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4_, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=None, use_oneshot=use_oneshot, **allreduce_params.get_trtllm_fused_allreduce_kwargs(), @@ -962,10 +962,15 @@ def get_fastest_baseline(op_name, results_dict): return prepared_results -def print_results(results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode): +def print_results( + results_dict, seq_len, hidden_dim, dtype, use_residual, quant_mode, input_size_mb +): """Print benchmark results in a formatted table.""" print(f"\n{'=' * 80}") - print(f"Results: seq_len={seq_len}, hidden_dim={hidden_dim}") + print( + f"Results: seq_len={seq_len}, hidden_dim={hidden_dim} " + f"(input size: {input_size_mb:.2f} MB)" + ) print( f"dtype={dtype}, residual={'yes' if use_residual else 'no'}, " f"quant_mode={quant_mode}" @@ -1009,11 +1014,12 @@ def format_results_markdown( dtype = result["dtype"] use_residual = result["use_residual"] results_dict = result["results"] - + input_size_mb = result["input_size_mb"] residual_str = "with residual" if use_residual else "no residual" markdown += f""" ## Configuration: seq_len={seq_len}, dtype={dtype}, {residual_str} +**Input Size:** {input_size_mb:.2f} MB | Operation | Time (ms) | Speedup | |-----------|-----------|---------| @@ -1234,6 +1240,10 @@ def main(): # Store results for markdown export if rank == 0: + # Calculate input size in MB + input_size_mb = ( + seq_len * args.hidden_dim * torch.finfo(dtype).bits + ) / (8 * 1024 * 1024) all_results.append( { "seq_len": seq_len, @@ -1241,6 +1251,7 @@ def main(): "dtype": str(dtype).replace("torch.", ""), "use_residual": use_residual, "quant_mode": quant_mode, + "input_size_mb": input_size_mb, "results": results, } ) @@ -1252,6 +1263,7 @@ def main(): dtype, use_residual, quant_mode, + input_size_mb, ) # Save results to markdown file From fde8bd18163642ae680809b23afaa64ba2a686ea Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 8 Sep 2025 04:41:09 -0700 Subject: [PATCH 3/4] Update threshold configuration Signed-off-by: ilmarkov --- vllm/config/compilation.py | 58 ++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 039e92cf0ef1..6b000b4ed12a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -112,40 +112,24 @@ class PassConfig: def flashinfer_max_size(self, world_size: int) -> Optional[int]: """ Returns the max communication size in bytes for flashinfer - allreduce fusion for the given world size. Falls back to - conservative defaults if the world size is not specified in config. + allreduce fusion for the given world size. 
Returns None if world size + is not supported by configs as it's not supported by flashinfer. """ # import here to avoid circular dependencies from vllm.platforms import current_platform MiB = 1024 * 1024 - # Max size of the input tensor per world size per device capability - # to use flashinfer fused allreduce - _FI_ALLREDUCE_MAX_INPUT_SIZES = { - "9.0": { - 2: 64 * MiB, # 64MB - 4: 2 * MiB, # 2MB - 8: 1 * MiB, # 1MB - }, - "10.0": { - 2: 64 * MiB, # 64MB - 4: 32 * MiB, # 32MB - 8: 1 * MiB, # 1MB - }, - } - device_capability = current_platform.get_device_capability( ).as_version_str() - max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}) - max_sizes.update({ + fi_allreduce_fusion_max_size_mb = \ + self.fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + max_sizes = { k: int(v * MiB) - for k, v in self.fi_allreduce_fusion_max_size_mb.items() - }) - if world_size not in max_sizes: - # FlashInfer doesn't support other world sizes - return None - return max_sizes[world_size] + for k, v in fi_allreduce_fusion_max_size_mb.items() + } + # return None if world size is not supported by flashinfer + return max_sizes.get(world_size) def uuid(self): """ @@ -169,6 +153,30 @@ def __post_init__(self) -> None: logger.warning_once( "Fusion enabled but reshape elimination disabled. " "Allreduce + rms norm + quant (fp8) fusion might not work") + # import here to avoid circular dependencies + from vllm.platforms import current_platform + + # Default tuned max size of the input tensor + # per world size per device capability + # to use flashinfer fused allreduce + fi_allreduce_fusion_max_size_mb = { + "9.0": { + 2: 64, # 64MB + 4: 2, # 2MB + 8: 1, # 1MB + }, + "10.0": { + 2: 64, # 64MB + 4: 32, # 32MB + 8: 1, # 1MB + }, + } + device_capability = current_platform.get_device_capability( + ).as_version_str() + + max_sizes = fi_allreduce_fusion_max_size_mb.get(device_capability, {}) + max_sizes.update(self.fi_allreduce_fusion_max_size_mb) + self.fi_allreduce_fusion_max_size_mb[device_capability] = max_sizes @config From 61ebc9566e2e9df9437d3b5a14a63fac42ec4723 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 8 Sep 2025 05:01:47 -0700 Subject: [PATCH 4/4] Move all_reduce from custom op in fused_moe Signed-off-by: ilmarkov --- vllm/model_executor/layers/fused_moe/layer.py | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 272ad3956537..2cf2dfb4398d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1599,6 +1599,19 @@ def forward( (0, self.hidden_size - og_hidden_states), mode='constant', value=0.0) + do_naive_dispatch_combine: bool = ( + self.dp_size > 1 + and not self.moe_parallel_config.use_deepep_ht_kernels + and not self.moe_config.use_flashinfer_cutlass_kernels) + + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: + states = get_ep_group().combine(states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + states = self.maybe_all_reduce_tensor_model_parallel(states) + + return states if self.shared_experts is None: if current_platform.is_tpu(): @@ -1609,7 +1622,7 @@ def forward( else: fused_output = torch.ops.vllm.moe_forward( hidden_states, router_logits, self.layer_name) - return fused_output[..., :og_hidden_states] + return reduce_output(fused_output[..., :og_hidden_states]) else: if current_platform.is_tpu(): # TODO: Once 
the OOM issue for the TPU backend is resolved, we @@ -1619,8 +1632,8 @@ def forward( else: shared_output, fused_output = torch.ops.vllm.moe_forward_shared( hidden_states, router_logits, self.layer_name) - return (shared_output[..., :og_hidden_states], - fused_output[..., :og_hidden_states]) + return (reduce_output(shared_output[..., :og_hidden_states]), + reduce_output(fused_output[..., :og_hidden_states])) def forward_impl_chunked( self, @@ -1786,23 +1799,7 @@ def forward_impl( shared_output, final_hidden_states, ) - - def reduce_output(states: torch.Tensor) -> torch.Tensor: - if do_naive_dispatch_combine: - states = get_ep_group().combine(states) - - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - states = self.maybe_all_reduce_tensor_model_parallel(states) - - return states - - if self.shared_experts is None: - return reduce_output(final_hidden_states) - else: - return ( - reduce_output(final_hidden_states[0]), - reduce_output(final_hidden_states[1]), - ) + return final_hidden_states @classmethod def make_expert_params_mapping(
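The net effect of the hunk above is that the expert-parallel combine and the tensor/expert-parallel all-reduce now run in forward() and are applied uniformly to both the routed and shared-expert outputs. A minimal sketch of that ordering with illustrative callables (not the actual layer API):

    from typing import Callable
    import torch

    def reduce_moe_output(states: torch.Tensor,
                          do_naive_dispatch_combine: bool,
                          ep_combine: Callable[[torch.Tensor], torch.Tensor],
                          tp_all_reduce: Callable[[torch.Tensor], torch.Tensor],
                          reduce_results: bool,
                          tp_size: int,
                          ep_size: int) -> torch.Tensor:
        # First gather naively-dispatched tokens back across the EP group,
        # then reduce the partial sums across TP/EP ranks if requested.
        if do_naive_dispatch_combine:
            states = ep_combine(states)
        if reduce_results and (tp_size > 1 or ep_size > 1):
            states = tp_all_reduce(states)
        return states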