diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index be7044c41a73..39e7b40f5a87 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -343,6 +343,7 @@ steps: - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py + - pytest -v -s compile/test_compile_ranges.py - label: PyTorch Fullgraph Smoke Test # 9min mirror_hardwares: [amdexperimental] diff --git a/tests/compile/test_compile_ranges.py b/tests/compile/test_compile_ranges.py new file mode 100644 index 000000000000..6759da199f4b --- /dev/null +++ b/tests/compile/test_compile_ranges.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 64 +MLP_SIZE = 128 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@support_torch_compile +class TestModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + +@torch.inference_mode +def run_model(vllm_config: VllmConfig, model: nn.Module, + batch_sizes: list[int]): + with set_forward_context({}, vllm_config=vllm_config): + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + for batch_size in batch_sizes: + model(torch.randn(batch_size, MLP_SIZE).cuda()) + + +def test_compile_ranges(): + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + compile_ranges_split_points=[8, 32], + )) + + with set_current_vllm_config(vllm_config): + model = TestModel(vllm_config=vllm_config, prefix='').eval().cuda() + batch_sizes = [1, 16, 48] + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_backend_compilations=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + run_model(vllm_config, model, batch_sizes) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3361b65a9b88..24b032bf8d63 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -59,7 +59,8 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[Optional[int], int, str], Any] = dict() + self.cache: dict[tuple[Optional[tuple[int, int]], int, str], + Any] = (dict()) self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = 
make_compiler(compilation_config) @@ -110,35 +111,47 @@ def save_to_file(self): with open(self.cache_file_path, "w") as f: f.write(data) - def load(self, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Optional[Callable]: - if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + def load( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + compile_range: Optional[tuple[int, int]] = None, + ) -> Optional[Callable]: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + handle = self.cache[(compile_range, graph_index, self.compiler.name)] compiled_graph = self.compiler.load(handle, graph, example_inputs, - graph_index, runtime_shape) - if runtime_shape is None: + graph_index, compile_range) + if compile_range is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via " - "handle %s", graph_index, self.compiler.name, handle) + "Directly load the %s-th graph for dynamic shape " + "from %s via handle %s", + graph_index, + self.compiler.name, + handle, + ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via " - "handle %s", graph_index, str(runtime_shape), - self.compiler.name, handle) + "Directly load the %s-th graph for compile range %s " + "from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) return compiled_graph - def compile(self, - graph: fx.GraphModule, - example_inputs, - additional_inductor_config, - compilation_config: CompilationConfig, - graph_index: int = 0, - num_graphs: int = 1, - runtime_shape: Optional[int] = None) -> Any: + def compile( + self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + compile_range: Optional[tuple[int, int]] = None, + ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time global compilation_start_time @@ -150,22 +163,27 @@ def compile(self, # try to load from the cache compiled_graph = self.load(graph, example_inputs, graph_index, - runtime_shape) + compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. # there can be multiple graphs due to piecewise compilation. 
now = time.time() elapsed = now - compilation_start_time - if runtime_shape is None: + if compile_range is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " - "from the cache, took %.3f s", elapsed) + "from the cache, took %.3f s", + elapsed, + ) else: logger.info( - "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", str(runtime_shape), - elapsed) + "Directly load the compiled graph(s) " + "for compile range %s " + "from the cache, took %.3f s", + str(compile_range), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -174,49 +192,65 @@ def compile(self, # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = \ - f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" + maybe_key = f"artifact_shape_{compile_range}_subgraph_{graph_index}" compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, runtime_shape, - maybe_key) + graph, + example_inputs, + additional_inductor_config, + compile_range, + maybe_key, + ) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(runtime_shape, graph_index, - self.compiler.name)] = handle + self.cache[(compile_range, graph_index, + self.compiler.name)] = (handle) compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - if runtime_shape is None: + if compile_range is None: logger.info( "Cache the graph for dynamic shape for later use") else: - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - if runtime_shape is None: + logger.info( + "Cache the graph of compile range %s for later use", + str(compile_range), + ) + if compile_range is None: logger.debug( "Store the %s-th graph for dynamic shape from %s via " - "handle %s", graph_index, self.compiler.name, handle) + "handle %s", + graph_index, + self.compiler.name, + handle, + ) else: logger.debug( - "Store the %s-th graph for shape %s from %s via handle %s", - graph_index, str(runtime_shape), self.compiler.name, - handle) + "Store the %s-th graph for compile range %s " + "from %s via handle %s", + graph_index, + str(compile_range), + self.compiler.name, + handle, + ) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: now = time.time() elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed - if runtime_shape is None: + if compile_range is None: logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed) else: - logger.info("Compiling a graph for shape %s takes %.2f s", - runtime_shape, elapsed) + logger.info( + "Compiling a graph for compile range %s takes %.2f s", + compile_range, + elapsed, + ) return compiled_graph @@ -238,7 +272,7 @@ def split_graph(graph: fx.GraphModule, for node in graph.graph.nodes: if node.op in ("output", "placeholder"): continue - if node.op == 'call_function' and str(node.target) in ops: + if node.op == "call_function" and str(node.target) in ops: subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) @@ -254,7 +288,8 @@ def split_graph(graph: fx.GraphModule, graph, None, lambda node: node_to_subgraph_id[node], - keep_original_order=True) + keep_original_order=True, + ) outputs = [] @@ -297,6 +332,7 @@ def __init__(self, module: torch.fx.GraphModule, 
vllm_backend: "VllmBackend"): super().__init__(module) from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names self.compilation_config = vllm_config.compilation_config @@ -313,9 +349,12 @@ def run(self, *args): with self.fake_mode, enable_python_dispatcher(): return super().run(*fake_args) - def call_module(self, target: torch.fx.node.Target, - args: tuple[torch.fx.node.Argument, - ...], kwargs: dict[str, Any]) -> Any: + def call_module( + self, + target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, ...], + kwargs: dict[str, Any], + ) -> Any: assert isinstance(target, str) output = super().call_module(target, args, kwargs) @@ -325,24 +364,20 @@ def call_module(self, target: torch.fx.node.Target, sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - global compilation_start_time - compiled_graph_for_dynamic_shape = self.vllm_backend.\ - compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None) + # Lazy import here to avoid circular import from .cuda_graph import CUDAGraphOptions from .cuda_piecewise_backend import PiecewiseBackend piecewise_backend = PiecewiseBackend( - submod, self.vllm_config, index, - len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_dynamic_shape, self.vllm_backend) + submod, + self.vllm_config, + index, + len(self.compile_submod_names), + sym_shape_indices, + # compiled_graph_for_dynamic_shape, + self.vllm_backend, + ) if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: # resolve the static graph wrapper class (e.g. CUDAGraphWrapper @@ -361,7 +396,9 @@ def call_module(self, target: torch.fx.node.Target, cudagraph_options=CUDAGraphOptions( debug_log_enable=piecewise_backend.is_first_graph, gc_disable=not piecewise_backend.is_first_graph, - weak_ref_output=piecewise_backend.is_last_graph)) + weak_ref_output=piecewise_backend.is_last_graph, + ), + ) else: self.module.__dict__[target] = piecewise_backend @@ -379,8 +416,8 @@ def call_module(self, target: torch.fx.node.Target, def set_model_tag(tag: str): """Context manager to set the model tag.""" global model_tag - assert tag != model_tag, \ - f"Model tag {tag} is the same as the current tag {model_tag}." + assert (tag != model_tag + ), f"Model tag {tag} is the same as the current tag {model_tag}." 
old_tag = model_tag model_tag = tag try: @@ -490,7 +527,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_config.traced_files.clear() logger.debug( "Traced files (to be considered for compilation cache):\n%s", - "\n".join(forward_code_files)) + "\n".join(forward_code_files), + ) hash_content = [] for filepath in forward_code_files: hash_content.append(filepath) @@ -501,6 +539,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: with open(filepath) as f: hash_content.append(f.read()) import hashlib + code_hash = hashlib.md5("\n".join(hash_content).encode(), usedforsecurity=False).hexdigest() factors.append(code_hash) @@ -535,8 +574,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if disable_cache: logger.info("vLLM's torch.compile cache is disabled.") else: - logger.info("Using cache directory: %s for vLLM's torch.compile", - local_cache_dir) + logger.info( + "Using cache directory: %s for vLLM's torch.compile", + local_cache_dir, + ) self.compiler_manager.initialize_cache(local_cache_dir, disable_cache, self.prefix) @@ -545,6 +586,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # transform and analysis are done compilation_counter.num_graphs_seen += 1 from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) self.compilation_config.compilation_time += dynamo_time @@ -583,8 +625,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not os.path.exists(graph_path): # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa # use `print_readable` because it can include submodules - src = "from __future__ import annotations\nimport torch\n" + \ - self.split_gm.print_readable(print_output=False) + src = ("from __future__ import annotations\nimport torch\n" + + self.split_gm.print_readable(print_output=False)) src = src.replace("", "GraphModule") with open(graph_path, "w") as f: f.write(src) @@ -593,12 +635,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True - if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ - not self.compilation_config.cudagraph_copy_inputs: + if (self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + or not self.compilation_config.cudagraph_copy_inputs): return self.split_gm # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() fake_args = [ fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t @@ -609,10 +652,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # for weights and static buffers, they will have concrete shapes. # symbolic shape only happens for input tensors. 
from torch.fx.experimental.symbolic_shapes import is_symbolic + self.sym_tensor_indices = [ i for i, x in enumerate(fake_args) - if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ - any(is_symbolic(d) for d in x.size()) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and any( + is_symbolic(d) for d in x.size()) ] # compiler managed cudagraph input buffers diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 71274420c342..97564e4f3602 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,7 +10,6 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -396,30 +395,21 @@ def __call__(self, graph: fx.Graph): _FI_WORKSPACE_TENSOR = None MiB = 1024 * 1024 - # Max size of the input tensor per world size - # to use flashinfer fused allreduce - _FI_MAX_SIZES = { - 2: 64 * MiB, # 64MB - 4: MiB, # 1MB - 6: MiB // 2, # 512KB - 8: MiB // 2, # 512KB + # Max size of the input tensor per world size per device capability + # to use flashinfer one shot fused allreduce + _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = { + "9.0": { + 2: 32 * MiB, # 32MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 32 * MiB, # 32MB + 4: 4 * MiB, # 4MB + 8: 1 * MiB, # 1MB + }, } - try: - _FI_MAX_SIZES.update({ - int(k): int(float(v) * MiB) - for k, v in - envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - }) - except Exception as e: - raise ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " - + str(e)) from e - - # opt for a more conservative default value - # when world size is not in _FI_MAX_SIZES - _DEFAULT_FI_MAX_SIZE = MiB // 2 - def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, residual: torch.Tensor, @@ -432,7 +422,6 @@ def call_trtllm_fused_allreduce_norm( fp32_acc: bool, max_token_num: int, pattern_code: int, - fuse_rms_quant: bool, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, scale_out: Optional[torch.Tensor] = None, @@ -441,82 +430,59 @@ def call_trtllm_fused_allreduce_norm( num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() current_tensor_size = num_tokens * hidden_size * element_size - max_fusion_size = max_token_num * hidden_size * element_size - use_flashinfer = current_tensor_size <= min( - _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), - max_fusion_size, - ) - if use_flashinfer: - assert (_FI_WORKSPACE_TENSOR is not None - ), "Flashinfer must be enabled when using flashinfer" - if norm_out is None: - norm_out = allreduce_in - residual_out = residual - else: - # return residual_out as allreduce_out with zeroed residual_in - # as flashinfer does not support rms_norm - # and allreduce_out together - residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce - flashinfer_comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - token_num=allreduce_in.shape[0], - residual_in=residual, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - world_rank=world_rank, - world_size=world_size, - hidden_dim=allreduce_in.shape[-1], - workspace_ptrs=_FI_WORKSPACE_TENSOR, - 
launch_with_pdl=launch_with_pdl, - use_oneshot=True, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=None, - quant_out=quant_out, - scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout. - SWIZZLED_128x4, - scale_factor=scale_factor, - ) + max_tensor_size = max_token_num * hidden_size * element_size + assert current_tensor_size <= max_tensor_size, \ + f"Current tensor size {current_tensor_size} is larger than " \ + f"max token num {max_token_num} * hidden size {hidden_size} * " \ + f"element size {element_size}" + device_capability = current_platform.get_device_capability( + ).as_version_str() + # Get one shot input size limit for the current world size + # for the current device capability + max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES. \ + get(device_capability, {}). \ + get(world_size, None) + # Use one shot if no max size is specified + use_oneshot = max_one_shot_size is None or \ + current_tensor_size <= max_one_shot_size + + assert ( + _FI_WORKSPACE_TENSOR + is not None), "Flashinfer must be enabled when using flashinfer" + if norm_out is None: + norm_out = allreduce_in + residual_out = residual else: - allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if (scale_factor is not None and scale_out is None - and fuse_rms_quant): - # Do fused rms norm static fp8 quant fused op - if norm_out is None: - torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, allreduce_out, residual, rms_gamma, - scale_factor, rms_eps) - else: - torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, - rms_eps) - else: - if norm_out is None: - torch.ops._C.fused_add_rms_norm(allreduce_out, residual, - rms_gamma, rms_eps) - norm_out = allreduce_out - else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, - rms_eps) - if scale_factor is not None: - if scale_out is not None: - torch.ops._C.scaled_fp4_quant(quant_out, norm_out, - scale_out, scale_factor) - else: - torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor) - if scale_factor is None or norm_out is not None: - # we need to return allreduce output - # in cases of non quant fused AR + RMS norm - # and fused AR + RMS norm + quant without fused add - allreduce_in.copy_(allreduce_out) + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=None, + quant_out=quant_out, + scale_out=scale_out, + # in vllm we only support swizzled layout + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + scale_factor=scale_factor, + ) def call_trtllm_fused_allreduce_norm_fake( allreduce_in: torch.Tensor, @@ -530,7 +496,6 @@ def call_trtllm_fused_allreduce_norm_fake( fp32_acc: bool, max_token_num: 
int, pattern_code: int, - fuse_rms_quant: bool, norm_out: Optional[torch.Tensor] = None, quant_out: Optional[torch.Tensor] = None, scale_out: Optional[torch.Tensor] = None, @@ -573,7 +538,6 @@ def __init__( self.fp32_acc = True self.use_oneshot = False self.max_token_num = max_token_num - self.fuse_rms_quant = fuse_rms_quant def get_trtllm_fused_allreduce_kwargs(self): return { @@ -583,7 +547,6 @@ def get_trtllm_fused_allreduce_kwargs(self): "trigger_completion_at_end": self.trigger_completion_at_end, "fp32_acc": self.fp32_acc, "max_token_num": self.max_token_num, - "fuse_rms_quant": self.fuse_rms_quant, } @@ -1089,24 +1052,27 @@ def __init__(self, config: VllmConfig): "Flashinfer is not installed or comm module not found, " "skipping allreduce fusion pass") return - # Check if the world size is supported - if self.tp_size not in _FI_MAX_SIZES: + max_size = config.compilation_config.\ + pass_config.flashinfer_max_size(self.tp_size) + if max_size is None: + # Flashinfer doesn't support current world size logger.warning( "Flashinfer allreduce fusion is not " "supported for world size %s", self.tp_size, ) return - max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) // - (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config. - fi_allreduce_fusion_max_token_num) + element_size = 4 if use_fp32_lamport else 2 + self.max_token_num = (max_size // (self.hidden_dim * element_size)) + # take the min to save workspace size and we'll never use more + # than max_num_batched_tokens anyways + self.max_token_num = min( + self.max_token_num, config.scheduler_config.max_num_batched_tokens) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, - max_token_num=max_num_token, + max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, @@ -1118,10 +1084,8 @@ def __init__(self, config: VllmConfig): rank=rank, world_size=self.tp_size, use_fp32_lamport=use_fp32_lamport, - max_token_num=max_num_token, - # fuse rms norm static fp8 quant fused op - # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) + max_token_num=self.max_token_num, + ) self.register_patterns() @@ -1172,6 +1136,12 @@ def register_patterns(self): self.disabled = False + def is_applicable_for_range( + self, compile_range: Optional[tuple[int, int]]) -> bool: + if compile_range is None: + return False + return compile_range[1] - 1 <= self.max_token_num + def __call__(self, graph: fx.Graph): if self.disabled: return diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 7158fd685964..a38a2d2b4484 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -64,16 +64,17 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Optional[tuple[int, int]] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: """ Compile the graph with the given example inputs and compiler config, - with a runtime shape. If the `runtime_shape` is None, it means + with a range. If the `compile_range` is None, it means the `example_inputs` have a dynamic shape. Otherwise, the - `runtime_shape` specifies the shape of the inputs. 
Right now we only - support one variable shape for all inputs, which is the batchsize - (number of tokens) during inference. + `compile_range` specifies the range of the inputs, + it could be concrete size, e.g. (4, 4). + Right now we only support one variable range of shapes for all inputs, + which is the batchsize (number of tokens) during inference. Dynamo will make sure `graph(*example_inputs)` is valid. @@ -98,7 +99,7 @@ def load(self, graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + compile_range: Optional[tuple[int, int]] = None) -> Callable: """ Load the compiled function from the handle. Raises an error if the handle is invalid. @@ -188,22 +189,25 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Optional[tuple[int, int]] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) - if isinstance(runtime_shape, int): - dynamic_shapes = "from_example_inputs" + if isinstance(compile_range, tuple): + if compile_range[0] == compile_range[1]: + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_graph" else: dynamic_shapes = "from_tracing_context" from torch._inductor import standalone_compile - with pass_context(runtime_shape): + with pass_context(compile_range): compiled_graph = standalone_compile( graph, example_inputs, @@ -223,7 +227,7 @@ def load(self, graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + compile_range: Optional[tuple[int, int]] = None) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) @@ -283,7 +287,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Optional[tuple[int, int]] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: compilation_counter.num_inductor_compiles += 1 @@ -296,7 +300,7 @@ def compile( current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - set_inductor_config(current_config, runtime_shape) + set_inductor_config(current_config, compile_range) # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -433,7 +437,7 @@ def _get_shape_env() -> AlwaysHitShapeEnv: torch._functorch.config.patch( enable_remote_autograd_cache=False)) - with pass_context(runtime_shape): + with pass_context(compile_range): compiled_graph = compile_fx( graph, example_inputs, @@ -461,7 +465,7 @@ def load(self, graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + compile_range: Optional[tuple[int, int]] = None) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) @@ -547,10 +551,12 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() -def set_inductor_config(config, runtime_shape): - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters +def 
set_inductor_config(config, compile_range): + if isinstance(compile_range, tuple): + # for a specific range of batchsizes, tuning triton kernel parameters # can be beneficial + #TODO(luka): max autotune only present with -O3, + # and this should live in config: https://github.com/vllm-project/vllm/issues/20283 config["max_autotune"] = True config["coordinate_descent_tuning"] = True @@ -563,7 +569,7 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, + compile_range: Optional[tuple[int, int]] = None, key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: compilation_counter.num_eager_compiles += 1 diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index ae26e9f1bf2b..f1d3998b7a15 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -6,7 +6,6 @@ import torch.fx as fx -import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig @@ -16,8 +15,8 @@ @dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int +class RangeEntry: + compile_range: tuple[int, int] compiled: bool = False runnable: Callable = None # type: ignore @@ -26,9 +25,7 @@ class PiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, piecewise_compile_index: int, total_piecewise_compiles: int, - sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend): + sym_shape_indices: list[int], vllm_backend: VllmBackend): """ The backend for piecewise compilation. It mainly handles the compilation of static shapes and @@ -51,67 +48,79 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set( - self.compilation_config.compile_sizes) + self.compile_ranges = self.compilation_config.get_compile_ranges() + log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}" + logger.debug_once(log_string) - self.first_run_finished = False + self.is_in_range = lambda x, range: range[0] <= x < range[1] if range[ + 0] < range[1] else x == range[0] - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + self.first_run_finished = False self.sym_shape_indices = sym_shape_indices - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to compile - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + # self.concrete_size_entries: dict[int, RangeEntry] = {} - # to_be_compiled_sizes tracks the remaining sizes to compile, + # the entries for ranges that we need to either + self.range_entries: dict[tuple[int, int], RangeEntry] = {} + + # to_be_compiled_ranges tracks the remaining ranges to compile, # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + self.to_be_compiled_ranges: set[tuple[int, + int]] = set(self.compile_ranges) # We only keep compilation management inside this class directly. 
- for shape in self.compile_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - runnable=self.compiled_graph_for_general_shape, - ) + for range in self.compile_ranges: + self.range_entries[range] = RangeEntry(compile_range=range, ) def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: + if (self.is_last_graph and not self.to_be_compiled_ranges): # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) + def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, + args) -> Any: + if not range_entry.compiled: + range_entry.compiled = True + self.to_be_compiled_ranges.remove(range_entry.compile_range) - entry = self.concrete_size_entries[runtime_shape] - - if not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( + range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape) + compile_range=range_entry.compile_range) # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() + self.check_for_ending_compilation() + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + + # Role of the general is taken by the last range + range_entry = self.range_entries[self.compile_ranges[-1]] + self._maybe_compile_for_range_entry(range_entry, args) + return range_entry.runnable(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + + range_found = False + for range in self.compile_ranges: + if self.is_in_range(runtime_shape, range): + range_entry = self.range_entries[range] + range_found = True + break + assert range_found, \ + f"Shape out of considered range: {runtime_shape} " \ + "[1, max_num_batched_tokens]" + + self._maybe_compile_for_range_entry(range_entry, args) - return entry.runnable(*args) + return range_entry.runnable(*args) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index e1b691df385d..7425271cd264 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -28,8 +28,8 @@ class PassContext: - def __init__(self, runtime_shape: Optional[int]): - self.runtime_shape = runtime_shape + def __init__(self, compile_range: Optional[tuple[int, int]]): + self.compile_range = compile_range def get_pass_context() -> PassContext: @@ -39,13 +39,13 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: Optional[int]): +def pass_context(compile_range: Optional[tuple[int, int]]): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. 
""" global _pass_context prev_context = _pass_context - _pass_context = PassContext(runtime_shape) + _pass_context = PassContext(compile_range) try: yield finally: @@ -96,7 +96,8 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_shape(self, shape: Optional[int]): + def is_applicable_for_range(self, compile_range: Optional[tuple[int, + int]]): return True diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 1b1cbe4fa12c..62ddacc5a102 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -43,9 +43,9 @@ def __init__(self): self.passes: list[VllmInductorPass] = [] def __call__(self, graph: fx.Graph): - shape = get_pass_context().runtime_shape + compile_range = get_pass_context().compile_range for pass_ in self.passes: - if pass_.is_applicable_for_shape(shape): + if pass_.is_applicable_for_range(compile_range): pass_(graph) # always run fix_functionalization last diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 1758ed4c86d2..2b2d46405aa9 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -471,9 +471,12 @@ def __init__(self, config: VllmConfig): # and allow multiple values of epsilon. torch._inductor.pattern_matcher._seen_patterns.clear() - def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + def is_applicable_for_range( + self, compile_range: Optional[tuple[int, int]]) -> bool: tp_size = get_tensor_model_parallel_world_size() - return shape is not None and shape % tp_size == 0 + return compile_range is not None and ( + compile_range[0] + == compile_range[1]) and (compile_range[1] % tp_size == 0) def __call__(self, graph: fx.Graph): self.begin() diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 941aff8919a9..927b77e860cc 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -3656,6 +3656,7 @@ def __post_init__(self): self._set_cudagraph_sizes() else: self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + self._set_compile_ranges() if self.cache_config.cpu_offload_gb > 0 and \ self.compilation_config.level != CompilationLevel.NO_COMPILATION \ @@ -3873,6 +3874,46 @@ def _set_cudagraph_sizes(self): self.compilation_config.init_with_cudagraph_sizes( batch_size_capture_list) + def _set_compile_ranges(self): + """ + Set the compile ranges for the compilation config. 
+ """ + compilation_config = self.compilation_config + computed_compile_ranges_split_points = [] + + # The upper bound of the compile ranges is the max_num_batched_tokens + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if max_num_batched_tokens is not None: + # We add 1 because the bounds checks in the compiler are exclusive + # and we want to include the max_num_batched_tokens + # in the compile range + computed_compile_ranges_split_points.append( + max_num_batched_tokens + 1) + + # Add the compile ranges for flashinfer + if compilation_config.pass_config.enable_fi_allreduce_fusion: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size( + tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() * + self.model_config.dtype.itemsize) + # We add 1 because the bounds checks in the compiler are + # exclusive and we want to include the max_token_num in the + # compile range + computed_compile_ranges_split_points.append(max_token_num + 1) + + if compilation_config.compile_ranges_split_points is not None: + for x in compilation_config.compile_ranges_split_points: + assert isinstance(x, int) + assert x > 0, f"Invalid compile range split point: {x}" + if (max_num_batched_tokens is not None + and x < max_num_batched_tokens and x > 1): + computed_compile_ranges_split_points.append(x) + compilation_config.compile_ranges_split_points = sorted( + computed_compile_ranges_split_points) # type: ignore + def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config model_config = self.model_config diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 677fb069bc07..516252791b0e 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -87,8 +87,65 @@ class PassConfig: """Whether to enable async TP.""" enable_fi_allreduce_fusion: bool = False """Whether to enable flashinfer allreduce fusion.""" - fi_allreduce_fusion_max_token_num: int = 16384 - """Max number of tokens to used in flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_size_mb: dict[int, + float] = field(default_factory=dict) + """The thresholds of the communicated tensor sizes under which + vllm should use flashinfer fused allreduce. Specified as a + dictionary mapping each world size to the threshold in MB + { : } + Unspecified world sizes will fallback to + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + }, where key is the device capability""" + + # TODO(luka) better pass enabling system. + + def flashinfer_max_size(self, world_size: int) -> Optional[int]: + """ + Returns the max communication size in bytes for flashinfer + allreduce fusion for the given world size. Falls back to + conservative defaults if the world size is not specified in config. 
+ """ + + # import here to avoid circular dependencies + from vllm.platforms import current_platform + MiB = 1024 * 1024 + + # Max size of the input tensor per world size per device capability + # to use flashinfer fused allreduce + _FI_ALLREDUCE_MAX_INPUT_SIZES = { + "9.0": { + 2: 64 * MiB, # 64MB + 4: 2 * MiB, # 2MB + 8: 1 * MiB, # 1MB + }, + "10.0": { + 2: 64 * MiB, # 64MB + 4: 32 * MiB, # 32MB + 8: 1 * MiB, # 1MB + }, + } + + device_capability = current_platform.get_device_capability( + ).as_version_str() + max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}) + max_sizes.update({ + k: int(v * MiB) + for k, v in self.fi_allreduce_fusion_max_size_mb.items() + }) + if world_size not in max_sizes: + # FlashInfer doesn't support other world sizes + return None + return max_sizes[world_size] # TODO(luka) better pass enabling system. @@ -137,6 +194,8 @@ class CompilationConfig: - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`compile_ranges_split_points`] + [vllm.config.CompilationConfig.compile_ranges_split_points] - [`inductor_compile_config`] [vllm.config.CompilationConfig.inductor_compile_config] - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] @@ -212,6 +271,16 @@ class CompilationConfig: """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" + compile_ranges_split_points: Optional[list[int]] = None + """Split points that represent compile ranges for inductor. + The compile ranges are + [1, split_points[0]), + [split_points[0], split_points[1]), ..., + [split_points[-1], max_num_batched_tokens + 1). + Compile sizes are also used single element ranges: + [compile_sizes[i], compile_sizes[i] + 1). + """ + inductor_compile_config: dict = field(default_factory=dict) """Additional configurations for inductor. - None: use default configurations.""" @@ -564,3 +633,26 @@ def set_splitting_ops_for_v1(self): def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( op in self.splitting_ops for op in self._attention_ops) + + def get_compile_ranges(self) -> list[tuple[int, int]]: + """Get the compile ranges for the compilation config.""" + compile_ranges_split_points = self.compile_ranges_split_points + compile_ranges = [] + # max_num_batched_tokens + 1 + max_split_point = max(compile_ranges_split_points) + compile_sizes = set(self.compile_sizes) + split_points = sorted( + compile_sizes.union(set(self.compile_ranges_split_points))) + # filter out split points that are greater + # than max_num_batched_tokens + 1 + split_points = [x for x in split_points if x <= max_split_point] + for i, s in enumerate(split_points): + if i == 0: + compile_ranges.append((1, s)) + else: + compile_ranges.append((split_points[i - 1], s)) + if s in compile_sizes and s != 1: + compile_ranges.append((s, s)) + assert compile_ranges[-1][1] == max_split_point, \ + "Last compile range end should be max_split_point" + return sorted(compile_ranges)
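
Note on the compile-range semantics introduced in this patch: compile_ranges_split_points and exact compile_sizes are merged into half-open ranges [lo, hi), every exact compile size s additionally gets a single-size entry (s, s), and the last range is capped at max_num_batched_tokens + 1 (the extra split point appended by _set_compile_ranges). The sketch below is a minimal standalone approximation of CompilationConfig.get_compile_ranges and the is_in_range dispatch in PiecewiseBackend; the helper names are illustrative, not part of the vLLM API.

    def build_compile_ranges(split_points: list[int], compile_sizes: set[int],
                             max_num_batched_tokens: int) -> list[tuple[int, int]]:
        # Half-open ranges [lo, hi); a (s, s) entry marks an exact compile size.
        upper = max_num_batched_tokens + 1  # exclusive upper bound
        points = sorted(p for p in set(split_points) | set(compile_sizes) if p <= upper)
        if not points or points[-1] != upper:
            points.append(upper)
        ranges: list[tuple[int, int]] = []
        prev = 1
        for p in points:
            ranges.append((prev, p))
            if p in compile_sizes and p != 1:
                ranges.append((p, p))  # exact-size entry, matched by equality below
            prev = p
        return sorted(ranges)

    def find_range(batch_size: int,
                   ranges: list[tuple[int, int]]) -> tuple[int, int]:
        # Mirrors PiecewiseBackend.is_in_range: half-open check for real ranges,
        # equality check for exact-size entries where lo == hi.
        for lo, hi in ranges:
            if (lo <= batch_size < hi) if lo < hi else (batch_size == lo):
                return (lo, hi)
        raise ValueError(f"batch size {batch_size} outside all compile ranges")

For example (hypothetical values), build_compile_ranges([8, 32], {64}, 8192) gives [(1, 8), (8, 32), (32, 64), (64, 64), (64, 8193)]; find_range(16, ...) returns (8, 32), while find_range(64, ...) hits the exact-size entry (64, 64) before the wider (64, 8193).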
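
The FlashInfer allreduce-fusion threshold works similarly: PassConfig.flashinfer_max_size starts from per-device-capability defaults keyed by world size, applies user overrides from fi_allreduce_fusion_max_size_mb (given in MB), and returns None when the world size is unsupported. A standalone sketch under those assumptions (it takes the capability string explicitly instead of querying current_platform):

    from typing import Optional

    MiB = 1024 * 1024

    # Default max input sizes per device capability and world size (from the patch).
    _FI_ALLREDUCE_MAX_INPUT_SIZES = {
        "9.0": {2: 64 * MiB, 4: 2 * MiB, 8: 1 * MiB},
        "10.0": {2: 64 * MiB, 4: 32 * MiB, 8: 1 * MiB},
    }

    def flashinfer_max_size(device_capability: str, world_size: int,
                            overrides_mb: dict[int, float]) -> Optional[int]:
        # Per-capability defaults, then user overrides (MB -> bytes); None means
        # FlashInfer fused allreduce is unsupported for this world size.
        sizes = dict(_FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {}))
        sizes.update({ws: int(mb * MiB) for ws, mb in overrides_mb.items()})
        return sizes.get(world_size)

For example, flashinfer_max_size("9.0", 4, {}) gives 2 MiB, while flashinfer_max_size("9.0", 6, {}) returns None and the pass is skipped. The returned byte limit bounds max_token_num = max_size // (hidden_size * element_size); _set_compile_ranges then adds max_token_num + 1 as a split point so that AllReduceFusionPass.is_applicable_for_range enables the fusion only for ranges that fit under the threshold.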