diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
new file mode 100644
index 000000000000..af267f804ffa
--- /dev/null
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -0,0 +1,92 @@
+import os
+from typing import List
+
+import pytest
+
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.activation import (GeluAndMul,
+                                                   ReLUSquaredActivation,
+                                                   SiluAndMul)
+from vllm.model_executor.layers.layernorm import RMSNorm
+
+
+# Registered subclass for test
+@CustomOp.register("relu3")
+class Relu3(ReLUSquaredActivation):
+    pass
+
+
+@pytest.mark.parametrize(
+    "env, torch_level, ops_enabled, default_on",
+    [
+        # Default values based on compile level
+        ("", 0, [True] * 4, True),
+        ("", 1, [True] * 4, True),
+        ("", 2, [True] * 4, True),  # All by default
+        ("", 3, [False] * 4, False),
+        ("", 4, [False] * 4, False),  # None by default
+        # Explicitly enabling/disabling
+        #
+        # Default: all
+        #
+        # All but SiluAndMul
+        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
+        # Only ReLU3
+        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
+        # All but SiluAndMul
+        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
+        # All but ReLU3 (even if ReLU2 is on)
+        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
+        # GeluAndMul and SiluAndMul
+        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
+        # All but RMSNorm
+        ("-rms_norm", 2, [0, 1, 1, 1], True),
+        #
+        # Default: none
+        #
+        # Only ReLU3
+        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
+        # All but RMSNorm
+        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
+    ])
+def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
+                     default_on: bool):
+    os.environ["VLLM_CUSTOM_OPS"] = env
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
+
+    # Reset default_on (computed once):
+    CustomOp.default_on.cache_clear()
+
+    assert CustomOp.default_on() == default_on
+
+    ops_enabled = [bool(x) for x in ops_enabled]
+
+    assert RMSNorm(1024).enabled() == ops_enabled[0]
+    assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
+
+    assert SiluAndMul().enabled() == ops_enabled[1]
+    assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
+
+    assert GeluAndMul().enabled() == ops_enabled[2]
+    assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
+
+    # If registered, subclasses should follow their own name
+    assert Relu3().enabled() == ops_enabled[3]
+    assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
+
+    # Unregistered subclass
+    class SiluAndMul2(SiluAndMul):
+        pass
+
+    # Subclasses should not require registration
+    assert SiluAndMul2().enabled() == SiluAndMul().enabled()
+
+
+@pytest.mark.parametrize(
+    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
+def test_enabled_ops_invalid(env: str):
+    os.environ["VLLM_CUSTOM_OPS"] = env
+    CustomOp.default_on.cache_clear()
+
+    with pytest.raises(AssertionError):
+        RMSNorm(1024).enabled()
diff --git a/vllm/envs.py b/vllm/envs.py
index 45a9999610f6..8bf86b300633 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -66,6 +66,7 @@
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
+    VLLM_CUSTOM_OPS: List[str] = []
     VLLM_DISABLED_KERNELS: List[str] = []
 
@@ -206,7 +207,17 @@ def get_default_config_root():
         os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
     "VLLM_TORCH_COMPILE_LEVEL":
     lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
-
+    # Fine-grained control over which custom ops to enable/disable.
+    # Use 'all' to enable all, 'none' to disable all.
+    # Also specify a list of custom op names to enable (prefixed with a '+'),
+    # or disable (prefixed with a '-').
+    # Examples:
+    # - 'all,-op1' to enable all except op1
+    # - 'none,+op1,+op2' to enable only op1 and op2
+    # By default, all custom ops are enabled when running without Inductor
+    # and disabled when running with Inductor (compile_level >= Inductor).
+    "VLLM_CUSTOM_OPS":
+    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index d0e90245ad01..549be116772c 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -1,14 +1,24 @@
+from functools import lru_cache
+from typing import Dict, Type
+
 import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip, is_xpu
+from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once
+
+logger = init_logger(__name__)
 
 
 class CustomOp(nn.Module):
+    """
+    Base class for custom ops.
+    Dispatches the forward method to the appropriate backend.
+    """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
@@ -17,7 +27,6 @@ def forward(self, *args, **kwargs):
 
     def forward_native(self, *args, **kwargs):
         """PyTorch-native implementation of the forward method.
-
         This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
         purposes.
@@ -56,7 +65,11 @@ def dispatch_forward(self):
 
         # NOTE(woosuk): Here we assume that vLLM was built for only one
         # specific backend. Currently, we do not support dynamic dispatching.
-        if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
+        enabled = self.enabled()
+        logger.debug("custom op %s %s", self.__class__.name,
+                     "enabled" if enabled else "disabled")
+
+        if not enabled:
             return self.forward_native
 
         if is_hip():
@@ -69,3 +82,50 @@ def dispatch_forward(self):
             return self.forward_xpu
         else:
             return self.forward_cuda
+
+    @classmethod
+    def enabled(cls) -> bool:
+        # if no name, then it was not registered
+        if not hasattr(cls, "name"):
+            print_warning_once(
+                f"Custom op {cls.__name__} was not registered, "
+                f"which means it won't appear in the op registry. "
+                f"It will be enabled/disabled based on the global settings.")
+            return CustomOp.default_on()
+
+        enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
+        disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
+        assert not (enabled
+                    and disabled), f"Cannot enable and disable {cls.name}"
+
+        return (CustomOp.default_on() or enabled) and not disabled
+
+    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
+    # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
+    @staticmethod
+    @lru_cache()
+    def default_on() -> bool:
+        count_none = envs.VLLM_CUSTOM_OPS.count("none")
+        count_all = envs.VLLM_CUSTOM_OPS.count("all")
+        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
+        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
+            not count_none > 0 or count_all > 0
+
+    # Dictionary of all custom ops (classes, indexed by registered name).
+    # To check if an op with a name is enabled, call .enabled() on the class.
+    # Examples:
+    # - MyOp.enabled()
+    # - op_registry["my_op"].enabled()
+    op_registry: Dict[str, Type['CustomOp']] = {}
+
+    # Decorator to register custom ops.
+    @classmethod
+    def register(cls, name: str):
+
+        def decorator(op_cls):
+            assert name not in cls.op_registry, f"Duplicate op name: {name}"
+            op_cls.name = name
+            cls.op_registry[name] = op_cls
+            return op_cls
+
+        return decorator
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index f2ea53cad9f2..cf99306c9cae 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -11,11 +11,13 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import LazyDict
 
 
+@CustomOp.register("fatrelu_and_mul")
 class FatreluAndMul(CustomOp):
     """An activation function for FATReLU.
-
+
     The function computes x -> FATReLU(x[:d]) * x[d:] where
     d = x.shape[-1] // 2.
     This is used in openbmb/MiniCPM-S-1B-sft.
@@ -40,6 +42,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return self.forward_native(x)
 
 
+@CustomOp.register("silu_and_mul")
 class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
 
@@ -74,6 +77,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+@CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
 
@@ -123,6 +127,7 @@ def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'
 
 
+@CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -144,6 +149,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return ops.gelu_new(x)
 
 
+@CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -164,8 +170,8 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return ops.gelu_fast(x)
 
 
+@CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
-
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -189,6 +195,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
 
 
+@CustomOp.register("relu2")
 class ReLUSquaredActivation(CustomOp):
     """
     Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
@@ -244,15 +251,22 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param_data.copy_(loaded_weight)
 
 
-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    "gelu_fast": FastGELU(),
-    "gelu_new": NewGELU(),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-    "relu2": ReLUSquaredActivation(),
-    "quick_gelu": QuickGELU(),
-}
+_ACTIVATION_REGISTRY = LazyDict({
+    "gelu":
+    lambda: nn.GELU(),
+    "gelu_fast":
+    lambda: FastGELU(),
+    "gelu_new":
+    lambda: NewGELU(),
+    "gelu_pytorch_tanh":
+    lambda: nn.GELU(approximate="tanh"),
+    "relu":
+    lambda: nn.ReLU(),
+    "relu2":
+    lambda: ReLUSquaredActivation(),
+    "quick_gelu":
+    lambda: QuickGELU(),
+})
 
 
 def get_act_fn(
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index bce740d0db75..8dd36620e3fa 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -37,13 +37,13 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor,
         raise NotImplementedError
 
 
+@CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
-
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                     2 * intermediate_size,
@@ -74,7 +74,6 @@ def apply(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         return self.forward(x=x,
                             layer=layer,
                             router_logits=router_logits,
@@ -97,7 +96,6 @@ def forward_cuda(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_experts)
 
@@ -134,7 +132,6 @@ def forward_tpu(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index d55f86056d17..10fae84dab72 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -7,6 +7,7 @@
 from vllm.model_executor.custom_op import CustomOp
 
 
+@CustomOp.register("rms_norm")
 class RMSNorm(CustomOp):
     """Root mean square normalization.
 
@@ -122,6 +123,7 @@ def extra_repr(self) -> str:
         return s
 
 
+@CustomOp.register("gemma_rms_norm")
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
 
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 2ed44e2093bb..2158ad333967 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -72,6 +72,7 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+@CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
@@ -468,7 +469,7 @@ def __init__(
         self.long_factor = long_factor
 
         scale = self.max_position_embeddings / \
-                self.original_max_position_embeddings
+            self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
         else:
diff --git a/vllm/utils.py b/vllm/utils.py
index 8debae52b288..07769da3c86d 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -17,6 +17,7 @@
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, ensure_future
+from collections.abc import Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -1442,3 +1443,24 @@ def dec(self, num=1):
     @property
     def value(self):
         return self._value
+
+
+# Adapted from: https://stackoverflow.com/a/47212782/5082708
+class LazyDict(Mapping, Generic[T]):
+
+    def __init__(self, factory: Dict[str, Callable[[], T]]):
+        self._factory = factory
+        self._dict: Dict[str, T] = {}
+
+    def __getitem__(self, key) -> T:
+        if key not in self._dict:
+            if key not in self._factory:
+                raise KeyError(key)
+            self._dict[key] = self._factory[key]()
+        return self._dict[key]
+
+    def __iter__(self):
+        return iter(self._factory)
+
+    def __len__(self):
+        return len(self._factory)
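
Usage sketch (a hypothetical standalone script, not part of the patch; it assumes a working vLLM installation that includes these changes and mirrors the assertions in the new test):

# Hypothetical sketch: enable every registered custom op except silu_and_mul,
# then query the toggles the same way the new test does.
import os

# Must be set before the first CustomOp.enabled() call, since
# CustomOp.default_on() is cached with lru_cache.
os.environ["VLLM_CUSTOM_OPS"] = "all,-silu_and_mul"
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "0"

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm

assert CustomOp.default_on()       # compile level 0 is below INDUCTOR
assert RMSNorm(1024).enabled()     # covered by 'all'
assert not SiluAndMul().enabled()  # explicitly disabled
assert not CustomOp.op_registry["silu_and_mul"].enabled()  # same check by name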