From 94118ed5babb1c4f3ef2002b9f1c95a92dd114f7 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 11 Oct 2024 18:23:58 -0400 Subject: [PATCH 01/10] add custom op enabling mechanism (with test) --- .../model_executor/test_enabled_custom_ops.py | 60 ++++++++++++++++ vllm/compilation/backends.py | 4 +- vllm/envs.py | 5 +- vllm/model_executor/custom_op.py | 71 ++++++++++++++++++- vllm/model_executor/layers/activation.py | 17 ++++- vllm/model_executor/layers/fused_moe/layer.py | 3 + vllm/model_executor/layers/layernorm.py | 4 +- .../model_executor/layers/rotary_embedding.py | 2 +- 8 files changed, 156 insertions(+), 10 deletions(-) create mode 100644 tests/model_executor/test_enabled_custom_ops.py diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py new file mode 100644 index 000000000000..2ef653a74987 --- /dev/null +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -0,0 +1,60 @@ +import os + +import pytest + +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.activation import (GeluAndMul, + ReLUSquaredActivation, + SiluAndMul) +from vllm.model_executor.layers.layernorm import RMSNorm + + +@pytest.mark.parametrize( + "env, torch_level, ops_enabled", + [ + # Default values based on compile level + ("", 0, [True] * 4), + ("", 1, [True] * 4), + ("", 2, [True] * 4), # All by default + ("", 3, [False] * 4), + ("", 4, [False] * 4), # None by default + # Explicitly enabling/disabling + # Default: all + ("rms_norm,-silu_and_mul", 0, [True, False, True, True] + ), # All but SiluAndMul + ("none,-rms_norm,relu2", 0, [False, False, False, True + ]), # Only ReLUSquaredActivation + ("all,-silu_and_mul", 1, [True, False, True, True + ]), # All but SiluAndMul + ("-relu2", 1, [True, True, True, False + ]), # All but ReLUSquaredActivation + ("none,-relu2,gelu_and_mul,silu_and_mul", 2, + [False, True, True, False]), # GeluAndMul and SiluAndMul + ("-rms_norm", 2, [False, True, True, True]), # All but RMSNorm + # Default: none + ("-silu_and_mul,relu2", 3, [False, False, False, True] + ), # Only ReLUSquaredActivation + ("all,-rms_norm", 4, [False, True, True, True]), # All but RMSNorm + ]) +def test_enabled_ops(env: str, torch_level: int, ops_enabled): + os.environ["VLLM_ENABLE_CUSTOM_OPS"] = env + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) + + # Enabling happens on import with this method + CustomOp._init_enabled_ops() + + assert RMSNorm(1024)._enabled() == ops_enabled[0] + assert SiluAndMul()._enabled() == ops_enabled[1] + assert GeluAndMul()._enabled() == ops_enabled[2] + assert ReLUSquaredActivation()._enabled() == ops_enabled[3] + + +@pytest.mark.parametrize("env", [ + "all,none", "-none", "-all", "all,rms_norm,all", "rms_norm,-rms_norm", + "RmsNorm" +]) +def test_enabled_ops_invalid(env: str): + os.environ["VLLM_ENABLE_CUSTOM_OPS"] = env + + with pytest.raises(AssertionError): + CustomOp._init_enabled_ops() diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4780358cea51..5f6b172f54e5 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -153,8 +153,8 @@ def fix_functionalization(graph: fx.Graph): graph.erase_node(node) # debug code, if we want to see the graph after the transformation - # with open("after.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) + with open("after.py", "w") as f: + print(graph.python_code(root_module="self", verbose=True).src, file=f) def wrap_inductor(graph, example_inputs, 
additional_inductor_config): diff --git a/vllm/envs.py b/vllm/envs.py index 8b541e5b78c0..efd2eb23b9fd 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -66,6 +66,7 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 + VLLM_ENABLE_CUSTOM_OPS: List[str] = [] def get_default_cache_root(): @@ -205,7 +206,9 @@ def get_default_config_root(): os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), "VLLM_TORCH_COMPILE_LEVEL": lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), - + "VLLM_ENABLE_CUSTOM_OPS": + lambda: os.environ.get("VLLM_ENABLE_CUSTOM_OPS", "").replace(" ", ""). + split(","), # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index d0e90245ad01..d414194a0a23 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,3 +1,6 @@ +import re +from typing import Set, Tuple + import torch.nn as nn import vllm.envs as envs @@ -7,9 +10,14 @@ class CustomOp(nn.Module): + """ + Base class for custom ops. + Dispatches the forward method to the appropriate backend. + """ - def __init__(self, *args, **kwargs): + def __init__(self, name: str): super().__init__() + self.name = name self._forward_method = self.dispatch_forward() def forward(self, *args, **kwargs): @@ -17,7 +25,6 @@ def forward(self, *args, **kwargs): def forward_native(self, *args, **kwargs): """PyTorch-native implementation of the forward method. - This method is optional. If implemented, it can be used with compilers such as torch.compile or PyTorch XLA. Also, it can be used for testing purposes. @@ -56,7 +63,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR: + if not self._enabled(): return self.forward_native if is_hip(): @@ -69,3 +76,61 @@ def dispatch_forward(self): return self.forward_xpu else: return self.forward_cuda + + @staticmethod + def _get_enabled_ops() -> Tuple[bool, Set[str], Set[str]]: + """ + Parse the VLLM_ENABLE_CUSTOM_OPS environment variable to determine + which custom ops are enabled. By default, custom ops are enabled + if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR. + Specifying 'all' or 'none' will override this default. 
+ + :return: A tuple of (default_on, enabled_ops, disabled_ops) + """ + + # filter empty strings + env_ops = list( + filter(lambda x: len(x) > 0, envs.VLLM_ENABLE_CUSTOM_OPS)) + + use_all, use_none = env_ops.count("all"), env_ops.count("none") + assert use_all + use_none <= 1, \ + "Cannot specify both 'all' and 'none' in VLLM_ENABLE_CUSTOM_OPS" + + # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR + default_on = envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR + + # override the default if 'all' or 'none' is specified + default_on = default_on and not bool(use_none) or bool(use_all) + enabled_ops, disabled_ops = set(), set() + + for op in env_ops: + if op == "all" or op == "none": + continue + + assert re.match(r"^-?[a-z0-9_]+$", + op), f"Invalid custom op name: '{op}'" + + if op.startswith("-"): + assert op[1:] not in {"all", "none"}, \ + f"Cannot disable all or none: '{op}'" + disabled_ops.add(op[1:]) + else: + enabled_ops.add(op) + + assert (len(enabled_ops & disabled_ops) == 0 + ), "Cannot enable and disable the same custom ops: " + str( + enabled_ops & disabled_ops) + + return default_on, enabled_ops, disabled_ops + + @classmethod + def _init_enabled_ops(cls): + cls.default_on, cls.enabled_ops, cls.disabled_ops = ( + cls._get_enabled_ops()) + + def _enabled(self) -> bool: + return ((CustomOp.default_on or self.name in CustomOp.enabled_ops) + and self.name not in CustomOp.disabled_ops) + + +CustomOp._init_enabled_ops() diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 43056786d35c..e52cf5c0287d 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -23,6 +23,9 @@ class SiluAndMul(CustomOp): return: (num_tokens, d) or (batch_size, seq_len, d) """ + def __init__(self): + super().__init__("silu_and_mul") + def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" d = x.shape[-1] // 2 @@ -58,7 +61,7 @@ class GeluAndMul(CustomOp): """ def __init__(self, approximate: str = "none"): - super().__init__() + super().__init__("gelu_and_mul") self.approximate = approximate if approximate not in ("none", "tanh"): raise ValueError(f"Unknown approximate mode: {approximate}") @@ -98,6 +101,9 @@ def extra_repr(self) -> str: class NewGELU(CustomOp): + def __init__(self): + super().__init__("gelu_new") + def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" c = math.sqrt(2.0 / math.pi) @@ -119,6 +125,9 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: class FastGELU(CustomOp): + def __init__(self): + super().__init__("gelu_fast") + def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * @@ -139,6 +148,9 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: class QuickGELU(CustomOp): + def __init__(self): + super().__init__("quick_gelu") + # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -167,6 +179,9 @@ class ReLUSquaredActivation(CustomOp): Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 """ + def __init__(self): + super().__init__("relu2") + def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native 
implementation equivalent to forward().""" return torch.square(F.relu(x)) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bce740d0db75..63e07ff72303 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -40,6 +40,9 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self): + super().__init__("unquantized_fused_moe") + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index d55f86056d17..9d832632f9dd 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -20,7 +20,7 @@ def __init__( eps: float = 1e-6, var_hidden_size: Optional[int] = None, ) -> None: - super().__init__() + super().__init__("rms_norm") self.hidden_size = hidden_size self.variance_epsilon = eps @@ -135,7 +135,7 @@ def __init__( hidden_size: int, eps: float = 1e-6, ) -> None: - super().__init__() + super().__init__("gemma_rms_norm") self.weight = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d4e9ed87ed54..aaa6407dc585 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -84,7 +84,7 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, ) -> None: - super().__init__() + super().__init__("rotary_embedding") self.head_size = head_size self.rotary_dim = rotary_dim self.max_position_embeddings = max_position_embeddings From 88eef3f1f50b9ea6c3a4ca1dabe9004f2c042678 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 11 Oct 2024 18:30:48 -0400 Subject: [PATCH 02/10] Add a comment --- vllm/envs.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index efd2eb23b9fd..6100de8d9c68 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -206,6 +206,15 @@ def get_default_config_root(): os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), "VLLM_TORCH_COMPILE_LEVEL": lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), + # Fine-grained control over which custom ops to enable. + # Use 'all' to enable all, 'none' to disable all. + # Also specify a list of custom op names to enable, + # or disable if the name is prefixed with a '-'. + # Examples: + # - 'all,-op1' to enable all except op1 + # - 'none,op1,op2' to enable only op1 and op2 + # By default, all custom ops are enabled when running without Inductor + # and disabled when running with Inductor (compile_level >= Inductor). "VLLM_ENABLE_CUSTOM_OPS": lambda: os.environ.get("VLLM_ENABLE_CUSTOM_OPS", "").replace(" ", ""). 
split(","), From 14a31053692957fa211de8af12a6a46441f5847e Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 11 Oct 2024 18:34:54 -0400 Subject: [PATCH 03/10] revert comment --- vllm/compilation/backends.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5f6b172f54e5..4780358cea51 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -153,8 +153,8 @@ def fix_functionalization(graph: fx.Graph): graph.erase_node(node) # debug code, if we want to see the graph after the transformation - with open("after.py", "w") as f: - print(graph.python_code(root_module="self", verbose=True).src, file=f) + # with open("after.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) def wrap_inductor(graph, example_inputs, additional_inductor_config): From 69e5444658b02ca97293d6e8293416dc9b1b497e Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 15 Oct 2024 16:32:41 -0400 Subject: [PATCH 04/10] PR comments: - removed complex initialization - name -> class property --- .../model_executor/test_enabled_custom_ops.py | 57 +++++++------- vllm/envs.py | 7 +- vllm/model_executor/custom_op.py | 76 +++++-------------- vllm/model_executor/layers/activation.py | 22 ++---- vllm/model_executor/layers/fused_moe/layer.py | 8 +- vllm/model_executor/layers/layernorm.py | 6 +- .../model_executor/layers/rotary_embedding.py | 3 +- vllm/model_executor/models/ultravox.py | 1 + 8 files changed, 63 insertions(+), 117 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 2ef653a74987..8857e08d806b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,4 +1,5 @@ import os +from typing import List import pytest @@ -10,38 +11,35 @@ @pytest.mark.parametrize( - "env, torch_level, ops_enabled", + "env, torch_level, ops_enabled, default_on", [ # Default values based on compile level - ("", 0, [True] * 4), - ("", 1, [True] * 4), - ("", 2, [True] * 4), # All by default - ("", 3, [False] * 4), - ("", 4, [False] * 4), # None by default + ("", 0, [True] * 4, True), + ("", 1, [True] * 4, True), + ("", 2, [True] * 4, True), # All by default + ("", 3, [False] * 4, False), + ("", 4, [False] * 4, False), # None by default # Explicitly enabling/disabling # Default: all - ("rms_norm,-silu_and_mul", 0, [True, False, True, True] - ), # All but SiluAndMul - ("none,-rms_norm,relu2", 0, [False, False, False, True - ]), # Only ReLUSquaredActivation - ("all,-silu_and_mul", 1, [True, False, True, True - ]), # All but SiluAndMul - ("-relu2", 1, [True, True, True, False - ]), # All but ReLUSquaredActivation - ("none,-relu2,gelu_and_mul,silu_and_mul", 2, - [False, True, True, False]), # GeluAndMul and SiluAndMul - ("-rms_norm", 2, [False, True, True, True]), # All but RMSNorm + ("+rms_norm,-silu_and_mul", 0, [True, False, True, True], True), # All but SiluAndMul + ("none,-rms_norm,+relu2", 0, [False, False, False, True], False), # Only ReLUSquaredActivation + ("all,-silu_and_mul", 1, [True, False, True, True], True), # All but SiluAndMul + ("-relu2", 1, [True, True, True, False], True), # All but ReLUSquaredActivation + ("none,-relu2,+gelu_and_mul,+silu_and_mul", 2,[False, True, True, False], False), # GeluAndMul and SiluAndMul + ("-rms_norm", 2, [False, True, True, True], True), # All but RMSNorm # Default: none - ("-silu_and_mul,relu2", 3, [False, False, False, True] - ), # Only 
ReLUSquaredActivation - ("all,-rms_norm", 4, [False, True, True, True]), # All but RMSNorm + ("-silu_and_mul,+relu2", 3, [False, False, False, True], False), # Only ReLUSquaredActivation + ("all,-rms_norm", 4, [False, True, True, True], True), # All but RMSNorm ]) -def test_enabled_ops(env: str, torch_level: int, ops_enabled): - os.environ["VLLM_ENABLE_CUSTOM_OPS"] = env +def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[bool], + default_on: bool): + os.environ["VLLM_CUSTOM_OPS"] = env os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) - # Enabling happens on import with this method - CustomOp._init_enabled_ops() + # Reset default_on (computed once): + CustomOp.default_on.cache_clear() + + assert CustomOp.default_on() == default_on assert RMSNorm(1024)._enabled() == ops_enabled[0] assert SiluAndMul()._enabled() == ops_enabled[1] @@ -49,12 +47,11 @@ def test_enabled_ops(env: str, torch_level: int, ops_enabled): assert ReLUSquaredActivation()._enabled() == ops_enabled[3] -@pytest.mark.parametrize("env", [ - "all,none", "-none", "-all", "all,rms_norm,all", "rms_norm,-rms_norm", - "RmsNorm" -]) +@pytest.mark.parametrize( + "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) def test_enabled_ops_invalid(env: str): - os.environ["VLLM_ENABLE_CUSTOM_OPS"] = env + os.environ["VLLM_CUSTOM_OPS"] = env + CustomOp.default_on.cache_clear() with pytest.raises(AssertionError): - CustomOp._init_enabled_ops() + RMSNorm(1024)._enabled() diff --git a/vllm/envs.py b/vllm/envs.py index 6100de8d9c68..e88a96f479ac 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -66,7 +66,7 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 - VLLM_ENABLE_CUSTOM_OPS: List[str] = [] + VLLM_CUSTOM_OPS: List[str] = [] def get_default_cache_root(): @@ -215,9 +215,8 @@ def get_default_config_root(): # - 'none,op1,op2' to enable only op1 and op2 # By default, all custom ops are enabled when running without Inductor # and disabled when running with Inductor (compile_level >= Inductor). - "VLLM_ENABLE_CUSTOM_OPS": - lambda: os.environ.get("VLLM_ENABLE_CUSTOM_OPS", "").replace(" ", ""). - split(","), + "VLLM_CUSTOM_OPS": + lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index d414194a0a23..0d62e8eabca7 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,5 +1,4 @@ -import re -from typing import Set, Tuple +from functools import lru_cache import torch.nn as nn @@ -15,9 +14,8 @@ class CustomOp(nn.Module): Dispatches the forward method to the appropriate backend. """ - def __init__(self, name: str): + def __init__(self): super().__init__() - self.name = name self._forward_method = self.dispatch_forward() def forward(self, *args, **kwargs): @@ -77,60 +75,22 @@ def dispatch_forward(self): else: return self.forward_cuda - @staticmethod - def _get_enabled_ops() -> Tuple[bool, Set[str], Set[str]]: - """ - Parse the VLLM_ENABLE_CUSTOM_OPS environment variable to determine - which custom ops are enabled. By default, custom ops are enabled - if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR. - Specifying 'all' or 'none' will override this default. 
- - :return: A tuple of (default_on, enabled_ops, disabled_ops) - """ - - # filter empty strings - env_ops = list( - filter(lambda x: len(x) > 0, envs.VLLM_ENABLE_CUSTOM_OPS)) - - use_all, use_none = env_ops.count("all"), env_ops.count("none") - assert use_all + use_none <= 1, \ - "Cannot specify both 'all' and 'none' in VLLM_ENABLE_CUSTOM_OPS" - - # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR - default_on = envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR - - # override the default if 'all' or 'none' is specified - default_on = default_on and not bool(use_none) or bool(use_all) - enabled_ops, disabled_ops = set(), set() - - for op in env_ops: - if op == "all" or op == "none": - continue - - assert re.match(r"^-?[a-z0-9_]+$", - op), f"Invalid custom op name: '{op}'" - - if op.startswith("-"): - assert op[1:] not in {"all", "none"}, \ - f"Cannot disable all or none: '{op}'" - disabled_ops.add(op[1:]) - else: - enabled_ops.add(op) - - assert (len(enabled_ops & disabled_ops) == 0 - ), "Cannot enable and disable the same custom ops: " + str( - enabled_ops & disabled_ops) - - return default_on, enabled_ops, disabled_ops - @classmethod - def _init_enabled_ops(cls): - cls.default_on, cls.enabled_ops, cls.disabled_ops = ( - cls._get_enabled_ops()) + def _enabled(cls) -> bool: + enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS + disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS + assert not (enabled + and disabled), f"Cannot enable and disable {cls.name}" - def _enabled(self) -> bool: - return ((CustomOp.default_on or self.name in CustomOp.enabled_ops) - and self.name not in CustomOp.disabled_ops) + return (CustomOp.default_on() or enabled) and not disabled - -CustomOp._init_enabled_ops() + # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR + # Specifying 'all' or 'none' will override this default. 
+ @staticmethod + @lru_cache() + def default_on() -> bool: + count_none = envs.VLLM_CUSTOM_OPS.count("none") + count_all = envs.VLLM_CUSTOM_OPS.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \ + not count_none > 0 or count_all > 0 diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index e52cf5c0287d..ed3217da471a 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -23,8 +23,7 @@ class SiluAndMul(CustomOp): return: (num_tokens, d) or (batch_size, seq_len, d) """ - def __init__(self): - super().__init__("silu_and_mul") + name = "silu_and_mul" def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -59,9 +58,10 @@ class GeluAndMul(CustomOp): x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) return: (batch_size, seq_len, d) or (num_tokens, d) """ + name = "gelu_and_mul" def __init__(self, approximate: str = "none"): - super().__init__("gelu_and_mul") + super().__init__() self.approximate = approximate if approximate not in ("none", "tanh"): raise ValueError(f"Unknown approximate mode: {approximate}") @@ -100,9 +100,7 @@ def extra_repr(self) -> str: class NewGELU(CustomOp): - - def __init__(self): - super().__init__("gelu_new") + name = "gelu_new" def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -124,9 +122,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: class FastGELU(CustomOp): - - def __init__(self): - super().__init__("gelu_fast") + name = "gelu_fast" def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -147,9 +143,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: class QuickGELU(CustomOp): - - def __init__(self): - super().__init__("quick_gelu") + name = "quick_gelu" # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 def forward_native(self, x: torch.Tensor) -> torch.Tensor: @@ -178,9 +172,7 @@ class ReLUSquaredActivation(CustomOp): """ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 """ - - def __init__(self): - super().__init__("relu2") + name = "relu2" def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 63e07ff72303..ba2fe2d0b97a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -39,14 +39,11 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - - def __init__(self): - super().__init__("unquantized_fused_moe") + name = "unquantized_fused_moe" def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter(torch.empty(num_experts, 2 * intermediate_size, @@ -77,7 +74,6 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None ) -> torch.Tensor: - return self.forward(x=x, layer=layer, router_logits=router_logits, @@ 
-100,7 +96,6 @@ def forward_cuda( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) @@ -137,7 +132,6 @@ def forward_tpu( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe assert not use_grouped_topk assert num_expert_group is None diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 9d832632f9dd..552a81311af3 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -13,6 +13,7 @@ class RMSNorm(CustomOp): Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. Refer to https://arxiv.org/abs/1910.07467 """ + name = "rms_norm" def __init__( self, @@ -20,7 +21,7 @@ def __init__( eps: float = 1e-6, var_hidden_size: Optional[int] = None, ) -> None: - super().__init__("rms_norm") + super().__init__() self.hidden_size = hidden_size self.variance_epsilon = eps @@ -129,13 +130,14 @@ class GemmaRMSNorm(CustomOp): 1. x * (1 + w) instead of x * w. 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. """ + name = "gemma_rms_norm" def __init__( self, hidden_size: int, eps: float = 1e-6, ) -> None: - super().__init__("gemma_rms_norm") + super().__init__() self.weight = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index aaa6407dc585..f7f4570fa4f7 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -74,6 +74,7 @@ def _apply_rotary_emb( class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" + name = "rotary_embedding" def __init__( self, @@ -84,7 +85,7 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, ) -> None: - super().__init__("rotary_embedding") + super().__init__() self.head_size = head_size self.rotary_dim = rotary_dim self.max_position_embeddings = max_position_embeddings diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index e162e3af008e..063f4d22e20a 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -231,6 +231,7 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: class FlippedSiluAndMul(SiluAndMul): """Ultravox is trained with SwiGLU with flipped halves.""" + name = "flipped_silu_and_mul" def forward(self, x: torch.Tensor): a, b = x.chunk(2, dim=-1) From 329f4afeeeab313f0c06659007ec9b819351f0d2 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 15 Oct 2024 16:36:21 -0400 Subject: [PATCH 05/10] Reformat test --- .../model_executor/test_enabled_custom_ops.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 8857e08d806b..edc8b2194bd7 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -20,18 +20,30 @@ ("", 3, [False] * 4, False), ("", 4, [False] * 4, False), # None by default # Explicitly enabling/disabling + # # Default: all - ("+rms_norm,-silu_and_mul", 0, [True, False, True, True], True), # All but SiluAndMul - ("none,-rms_norm,+relu2", 0, [False, False, False, True], False), # Only 
ReLUSquaredActivation - ("all,-silu_and_mul", 1, [True, False, True, True], True), # All but SiluAndMul - ("-relu2", 1, [True, True, True, False], True), # All but ReLUSquaredActivation - ("none,-relu2,+gelu_and_mul,+silu_and_mul", 2,[False, True, True, False], False), # GeluAndMul and SiluAndMul - ("-rms_norm", 2, [False, True, True, True], True), # All but RMSNorm + # + # All but SiluAndMul + ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True), + # Only ReLUSquaredActivation + ("none,-rms_norm,+relu2", 0, [0, 0, 0, 1], False), + # All but SiluAndMul + ("all,-silu_and_mul", 1, [1, 0, 1, 1], True), + # All but ReLUSquaredActivation + ("-relu2", 1, [1, 1, 1, 0], True), + # GeluAndMul and SiluAndMul + ("none,-relu2,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False), + # All but RMSNorm + ("-rms_norm", 2, [0, 1, 1, 1], True), + # # Default: none - ("-silu_and_mul,+relu2", 3, [False, False, False, True], False), # Only ReLUSquaredActivation - ("all,-rms_norm", 4, [False, True, True, True], True), # All but RMSNorm + # + # Only ReLUSquaredActivation + ("-silu_and_mul,+relu2", 3, [0, 0, 0, 1], False), + # All but RMSNorm + ("all,-rms_norm", 4, [0, 1, 1, 1], True), ]) -def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[bool], +def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): os.environ["VLLM_CUSTOM_OPS"] = env os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) @@ -41,10 +53,10 @@ def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[bool], assert CustomOp.default_on() == default_on - assert RMSNorm(1024)._enabled() == ops_enabled[0] - assert SiluAndMul()._enabled() == ops_enabled[1] - assert GeluAndMul()._enabled() == ops_enabled[2] - assert ReLUSquaredActivation()._enabled() == ops_enabled[3] + assert RMSNorm(1024)._enabled() == bool(ops_enabled[0]) + assert SiluAndMul()._enabled() == bool(ops_enabled[1]) + assert GeluAndMul()._enabled() == bool(ops_enabled[2]) + assert ReLUSquaredActivation()._enabled() == bool(ops_enabled[3]) @pytest.mark.parametrize( From e78e4de2aafd9e01c2e27d3ac561528153b60652 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 15 Oct 2024 17:05:06 -0400 Subject: [PATCH 06/10] Add custom op registry --- .../model_executor/test_enabled_custom_ops.py | 19 +++++++++++---- vllm/model_executor/custom_op.py | 24 ++++++++++++++++--- vllm/model_executor/layers/activation.py | 14 +++++------ vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/layernorm.py | 4 ++-- .../model_executor/layers/rotary_embedding.py | 4 ++-- 6 files changed, 46 insertions(+), 21 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index edc8b2194bd7..759958a30e8c 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -53,10 +53,19 @@ def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], assert CustomOp.default_on() == default_on - assert RMSNorm(1024)._enabled() == bool(ops_enabled[0]) - assert SiluAndMul()._enabled() == bool(ops_enabled[1]) - assert GeluAndMul()._enabled() == bool(ops_enabled[2]) - assert ReLUSquaredActivation()._enabled() == bool(ops_enabled[3]) + ops_enabled = [bool(x) for x in ops_enabled] + + assert RMSNorm(1024).enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + + assert SiluAndMul().enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] 
+ + assert GeluAndMul().enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + + assert ReLUSquaredActivation().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu2"].enabled() == ops_enabled[3] @pytest.mark.parametrize( @@ -66,4 +75,4 @@ def test_enabled_ops_invalid(env: str): CustomOp.default_on.cache_clear() with pytest.raises(AssertionError): - RMSNorm(1024)._enabled() + RMSNorm(1024).enabled() diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 0d62e8eabca7..9cf3305a9f95 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,4 +1,5 @@ from functools import lru_cache +from typing import Dict, Type import torch.nn as nn @@ -61,7 +62,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if not self._enabled(): + if not self.enabled(): return self.forward_native if is_hip(): @@ -76,7 +77,7 @@ def dispatch_forward(self): return self.forward_cuda @classmethod - def _enabled(cls) -> bool: + def enabled(cls) -> bool: enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS assert not (enabled @@ -85,7 +86,7 @@ def _enabled(cls) -> bool: return (CustomOp.default_on() or enabled) and not disabled # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR - # Specifying 'all' or 'none' will override this default. + # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence. @staticmethod @lru_cache() def default_on() -> bool: @@ -94,3 +95,20 @@ def default_on() -> bool: assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \ not count_none > 0 or count_all > 0 + + # Dictionary of all custom ops. + # To check if an op with a name is enabled, call .enabled() on the class. + # Example: op_registry["my_op"].enabled() + op_registry: Dict[str, Type['CustomOp']] = {} + + # Decorator to register custom ops. + @classmethod + def register(cls, name: str): + + def decorator(op_cls): + assert name not in cls.op_registry, f"Duplicate op name: {name}" + op_cls.name = name + cls.op_registry[name] = op_cls + return op_cls + + return decorator diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index ed3217da471a..c48c12ecec66 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -13,6 +13,7 @@ from vllm.model_executor.utils import set_weight_attrs +@CustomOp.register("silu_and_mul") class SiluAndMul(CustomOp): """An activation function for SwiGLU. @@ -23,8 +24,6 @@ class SiluAndMul(CustomOp): return: (num_tokens, d) or (batch_size, seq_len, d) """ - name = "silu_and_mul" - def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" d = x.shape[-1] // 2 @@ -49,6 +48,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: return out +@CustomOp.register("gelu_and_mul") class GeluAndMul(CustomOp): """An activation function for GeGLU. 
@@ -58,7 +58,6 @@ class GeluAndMul(CustomOp): x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) return: (batch_size, seq_len, d) or (num_tokens, d) """ - name = "gelu_and_mul" def __init__(self, approximate: str = "none"): super().__init__() @@ -99,8 +98,8 @@ def extra_repr(self) -> str: return f'approximate={repr(self.approximate)}' +@CustomOp.register("gelu_new") class NewGELU(CustomOp): - name = "gelu_new" def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -121,8 +120,8 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: return ops.gelu_new(x) +@CustomOp.register("gelu_fast") class FastGELU(CustomOp): - name = "gelu_fast" def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -142,9 +141,8 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: return ops.gelu_fast(x) +@CustomOp.register("quick_gelu") class QuickGELU(CustomOp): - name = "quick_gelu" - # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -168,11 +166,11 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: +@CustomOp.register("relu2") class ReLUSquaredActivation(CustomOp): """ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 """ - name = "relu2" def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ba2fe2d0b97a..8dd36620e3fa 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -37,9 +37,9 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, raise NotImplementedError +@CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - name = "unquantized_fused_moe" def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 552a81311af3..10fae84dab72 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -7,13 +7,13 @@ from vllm.model_executor.custom_op import CustomOp +@CustomOp.register("rms_norm") class RMSNorm(CustomOp): """Root mean square normalization. Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. Refer to https://arxiv.org/abs/1910.07467 """ - name = "rms_norm" def __init__( self, @@ -123,6 +123,7 @@ def extra_repr(self) -> str: return s +@CustomOp.register("gemma_rms_norm") class GemmaRMSNorm(CustomOp): """RMS normalization for Gemma. @@ -130,7 +131,6 @@ class GemmaRMSNorm(CustomOp): 1. x * (1 + w) instead of x * w. 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. 
""" - name = "gemma_rms_norm" def __init__( self, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index f7f4570fa4f7..dc8cfb5fd9dd 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -72,9 +72,9 @@ def _apply_rotary_emb( return torch.stack((o1, o2), dim=-1).flatten(-2) +@CustomOp.register("rotary_embedding") class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" - name = "rotary_embedding" def __init__( self, @@ -469,7 +469,7 @@ def __init__( self.long_factor = long_factor scale = self.max_position_embeddings / \ - self.original_max_position_embeddings + self.original_max_position_embeddings if scale <= 1.0: scaling_factor = 1.0 else: From 8f1fc443ed155e587026120e5e1e42a86b493f56 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 17 Oct 2024 10:20:22 -0400 Subject: [PATCH 07/10] PR comments: - warning when not registered - log when enabled - test subclass behavior - remove name from model FlippedSiluAndMul --- .../model_executor/test_enabled_custom_ops.py | 32 +++++++++++++------ vllm/model_executor/custom_op.py | 19 +++++++++-- vllm/model_executor/models/ultravox.py | 1 - 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 759958a30e8c..af267f804ffa 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -10,6 +10,12 @@ from vllm.model_executor.layers.layernorm import RMSNorm +# Registered subclass for test +@CustomOp.register("relu3") +class Relu3(ReLUSquaredActivation): + pass + + @pytest.mark.parametrize( "env, torch_level, ops_enabled, default_on", [ @@ -25,21 +31,21 @@ # # All but SiluAndMul ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True), - # Only ReLUSquaredActivation - ("none,-rms_norm,+relu2", 0, [0, 0, 0, 1], False), + # Only ReLU3 + ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False), # All but SiluAndMul ("all,-silu_and_mul", 1, [1, 0, 1, 1], True), - # All but ReLUSquaredActivation - ("-relu2", 1, [1, 1, 1, 0], True), + # All but ReLU3 (even if ReLU2 is on) + ("-relu3,relu2", 1, [1, 1, 1, 0], True), # GeluAndMul and SiluAndMul - ("none,-relu2,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False), + ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False), # All but RMSNorm ("-rms_norm", 2, [0, 1, 1, 1], True), # # Default: none # - # Only ReLUSquaredActivation - ("-silu_and_mul,+relu2", 3, [0, 0, 0, 1], False), + # Only ReLU3 + ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False), # All but RMSNorm ("all,-rms_norm", 4, [0, 1, 1, 1], True), ]) @@ -64,8 +70,16 @@ def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], assert GeluAndMul().enabled() == ops_enabled[2] assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] - assert ReLUSquaredActivation().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu2"].enabled() == ops_enabled[3] + # If registered, subclasses should follow their own name + assert Relu3().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] + + # Unregistered subclass + class SiluAndMul2(SiluAndMul): + pass + + # Subclasses should not require registration + assert SiluAndMul2().enabled() == SiluAndMul().enabled() @pytest.mark.parametrize( diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9cf3305a9f95..508baa8e3ff9 
100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -5,8 +5,11 @@ import vllm.envs as envs from vllm.compilation.levels import CompilationLevel +from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_cpu, is_hip, is_xpu +from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once + +logger = init_logger(__name__) class CustomOp(nn.Module): @@ -62,7 +65,11 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if not self.enabled(): + enabled = self.enabled() + logger.debug("custom op %s %s", self.__class__.name, + "enabled" if enabled else "disabled") + + if not enabled: return self.forward_native if is_hip(): @@ -78,6 +85,14 @@ def dispatch_forward(self): @classmethod def enabled(cls) -> bool: + # if no name, then it was not registered + if not hasattr(cls, "name"): + print_warning_once( + f"Custom op {cls.__name__} was not registered, " + f"which means it won't appear in the op registry. " + f"It will be enabled/disabled based on the global settings.") + return CustomOp.default_on() + enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS assert not (enabled diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 063f4d22e20a..e162e3af008e 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -231,7 +231,6 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: class FlippedSiluAndMul(SiluAndMul): """Ultravox is trained with SwiGLU with flipped halves.""" - name = "flipped_silu_and_mul" def forward(self, x: torch.Tensor): a, b = x.chunk(2, dim=-1) From fae9e2703693dcf0cb27358c57b680acd1de954e Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 17 Oct 2024 10:21:37 -0400 Subject: [PATCH 08/10] Fix comment in envs --- vllm/envs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index e88a96f479ac..5c32f2ad533e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -206,13 +206,13 @@ def get_default_config_root(): os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), "VLLM_TORCH_COMPILE_LEVEL": lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), - # Fine-grained control over which custom ops to enable. + # Fine-grained control over which custom ops to enable/disable. # Use 'all' to enable all, 'none' to disable all. - # Also specify a list of custom op names to enable, - # or disable if the name is prefixed with a '-'. + # Also specify a list of custom op names to enable (prefixed with a '+'), + # or disable (prefixed with a '-'). # Examples: # - 'all,-op1' to enable all except op1 - # - 'none,op1,op2' to enable only op1 and op2 + # - 'none,+op1,+op2' to enable only op1 and op2 # By default, all custom ops are enabled when running without Inductor # and disabled when running with Inductor (compile_level >= Inductor). 
"VLLM_CUSTOM_OPS": From 5201dc6c8323dc3061e84662ba6be1ec36352ca0 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 17 Oct 2024 10:23:37 -0400 Subject: [PATCH 09/10] Improve custom_op comment --- vllm/model_executor/custom_op.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 508baa8e3ff9..549be116772c 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -111,9 +111,11 @@ def default_on() -> bool: return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \ not count_none > 0 or count_all > 0 - # Dictionary of all custom ops. + # Dictionary of all custom ops (classes, indexed by registered name). # To check if an op with a name is enabled, call .enabled() on the class. - # Example: op_registry["my_op"].enabled() + # Examples: + # - MyOp.enabled() + # - op_registry["my_op"].enabled() op_registry: Dict[str, Type['CustomOp']] = {} # Decorator to register custom ops. From 408d576795fdd58b0f0174614832d72df850005c Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 17 Oct 2024 11:35:10 -0400 Subject: [PATCH 10/10] Lazy init for activations --- vllm/model_executor/layers/activation.py | 26 ++++++++++++++++-------- vllm/utils.py | 22 ++++++++++++++++++++ 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index ac1bbdc3e729..cf99306c9cae 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -11,6 +11,7 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import LazyDict @CustomOp.register("fatrelu_and_mul") @@ -250,15 +251,22 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data.copy_(loaded_weight) -_ACTIVATION_REGISTRY = { - "gelu": nn.GELU(), - "gelu_fast": FastGELU(), - "gelu_new": NewGELU(), - "gelu_pytorch_tanh": nn.GELU(approximate="tanh"), - "relu": nn.ReLU(), - "relu2": ReLUSquaredActivation(), - "quick_gelu": QuickGELU(), -} +_ACTIVATION_REGISTRY = LazyDict({ + "gelu": + lambda: nn.GELU(), + "gelu_fast": + lambda: FastGELU(), + "gelu_new": + lambda: NewGELU(), + "gelu_pytorch_tanh": + lambda: nn.GELU(approximate="tanh"), + "relu": + lambda: nn.ReLU(), + "relu2": + lambda: ReLUSquaredActivation(), + "quick_gelu": + lambda: QuickGELU(), +}) def get_act_fn( diff --git a/vllm/utils.py b/vllm/utils.py index 8debae52b288..07769da3c86d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -17,6 +17,7 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, ensure_future +from collections.abc import Mapping from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, @@ -1442,3 +1443,24 @@ def dec(self, num=1): @property def value(self): return self._value + + +# Adapted from: https://stackoverflow.com/a/47212782/5082708 +class LazyDict(Mapping, Generic[T]): + + def __init__(self, factory: Dict[str, Callable[[], T]]): + self._factory = factory + self._dict: Dict[str, T] = {} + + def __getitem__(self, key) -> T: + if key not in self._dict: + if key not in self._factory: + raise KeyError(key) + self._dict[key] = self._factory[key]() + return self._dict[key] + + def __iter__(self): + return iter(self._factory) + + def __len__(self): + return 
len(self._factory)
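A small usage sketch (not part of the patch series) may help tie the pieces together. It mirrors the behavior exercised by tests/model_executor/test_enabled_custom_ops.py against the final state of the series: ops are toggled through the VLLM_CUSTOM_OPS and VLLM_TORCH_COMPILE_LEVEL environment variables, registered names live in CustomOp.op_registry, and enabled() reports the effective state. It assumes a checkout with these patches applied; the op name "my_relu" is purely hypothetical.

import os

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul

# Disable everything by default, then opt back in to a single op.
os.environ["VLLM_CUSTOM_OPS"] = "none,+silu_and_mul"
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "0"

# default_on() is lru_cache'd, so reset it after changing the environment,
# exactly as the added test does.
CustomOp.default_on.cache_clear()

assert not CustomOp.default_on()       # 'none' overrides the compile-level default
assert SiluAndMul().enabled()          # explicitly re-enabled via '+silu_and_mul'
assert CustomOp.op_registry["silu_and_mul"].enabled()


# Registering a new op uses the same decorator the diff applies to the
# built-in layers; "my_relu" is a made-up name for illustration only.
@CustomOp.register("my_relu")
class MyReLU(SiluAndMul):
    pass


assert not MyReLU.enabled()            # not listed in the env var, and the default is 'none'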
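The LazyDict introduced in the final patch defers construction of the activation modules until they are actually looked up through _ACTIVATION_REGISTRY. A minimal illustration of that behavior follows (assuming the patched vllm.utils is importable; the string values merely stand in for real nn.Module instances):

from vllm.utils import LazyDict

calls = []

registry = LazyDict({
    "gelu": lambda: calls.append("gelu") or "gelu-module",
    "relu": lambda: calls.append("relu") or "relu-module",
})

assert calls == []                     # nothing is built at definition time
assert registry["gelu"] == "gelu-module"
assert calls == ["gelu"]               # only the requested entry was constructed
registry["gelu"]
assert calls == ["gelu"]               # cached: each factory runs at most once
assert len(registry) == 2              # __iter__/__len__ still cover all declared keys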