92 changes: 92 additions & 0 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -0,0 +1,92 @@
import os
from typing import List

import pytest

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm


# Registered subclass used by the test below
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
pass


@pytest.mark.parametrize(
"env, torch_level, ops_enabled, default_on",
[
# Default values based on compile level
("", 0, [True] * 4, True),
("", 1, [True] * 4, True),
("", 2, [True] * 4, True), # All by default
("", 3, [False] * 4, False),
("", 4, [False] * 4, False), # None by default
# Explicitly enabling/disabling
#
# Default: all
#
# All but SiluAndMul
("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
# Only ReLU3
("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
# All but SiluAndMul
("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
# All but ReLU3 (even if ReLU2 is on)
("-relu3,relu2", 1, [1, 1, 1, 0], True),
# GeluAndMul and SiluAndMul
("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
# All but RMSNorm
("-rms_norm", 2, [0, 1, 1, 1], True),
#
# Default: none
#
# Only ReLU3
("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 4, [0, 1, 1, 1], True),
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)

# Reset default_on (computed once):
CustomOp.default_on.cache_clear()

assert CustomOp.default_on() == default_on

ops_enabled = [bool(x) for x in ops_enabled]

assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

# If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass

# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env
CustomOp.default_on.cache_clear()

with pytest.raises(AssertionError):
RMSNorm(1024).enabled()
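The four ops_enabled flags correspond, in order, to RMSNorm, SiluAndMul, GeluAndMul, and Relu3. A minimal sketch of exercising the same toggles outside the test harness, assuming the imports above and this branch of vLLM (illustrative only):

import os

os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "2"  # below CompilationLevel.INDUCTOR
os.environ["VLLM_CUSTOM_OPS"] = "-rms_norm"   # disable only rms_norm

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm

CustomOp.default_on.cache_clear()   # default_on() is cached, so clear it after changing env vars
assert CustomOp.default_on()        # level 2 < INDUCTOR, so ops are on by default
assert not RMSNorm(1024).enabled()  # '-rms_norm' overrides the default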
13 changes: 12 additions & 1 deletion vllm/envs.py
@@ -66,6 +66,7 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []


@@ -206,7 +207,17 @@ def get_default_config_root():
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),

# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all ops and 'none' to disable all ops.
# Individual ops can additionally be enabled (prefixed with '+') or
# disabled (prefixed with '-') by their registered names.
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor
# (VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR).
"VLLM_CUSTOM_OPS":
lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
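For reference, the parsing above just strips spaces and splits on commas; a small illustration with hypothetical op names op1 and op2:

import os

os.environ["VLLM_CUSTOM_OPS"] = "none, +op1, +op2"
parsed = os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(",")
print(parsed)  # ['none', '+op1', '+op2']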
68 changes: 64 additions & 4 deletions vllm/model_executor/custom_op.py
@@ -1,14 +1,24 @@
from functools import lru_cache
from typing import Dict, Type

import torch.nn as nn

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
from vllm.utils import is_cpu, is_hip, is_xpu, print_warning_once

logger = init_logger(__name__)


class CustomOp(nn.Module):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""

def __init__(self, *args, **kwargs):
def __init__(self):
super().__init__()
self._forward_method = self.dispatch_forward()

@@ -17,7 +27,6 @@ def forward(self, *args, **kwargs):

def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.

This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
@@ -56,7 +65,11 @@ def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.

if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
enabled = self.enabled()
logger.debug("custom op %s %s", self.__class__.name,
"enabled" if enabled else "disabled")

if not enabled:
return self.forward_native

if is_hip():
@@ -69,3 +82,50 @@
return self.forward_xpu
else:
return self.forward_cuda

@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered
if not hasattr(cls, "name"):
print_warning_once(
f"Custom op {cls.__name__} was not registered, "
f"which means it won't appear in the op registry. "
f"It will be enabled/disabled based on the global settings.")
return CustomOp.default_on()

enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
assert not (enabled
and disabled), f"Cannot enable and disable {cls.name}"

return (CustomOp.default_on() or enabled) and not disabled

# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@staticmethod
@lru_cache()
def default_on() -> bool:
count_none = envs.VLLM_CUSTOM_OPS.count("none")
count_all = envs.VLLM_CUSTOM_OPS.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
not count_none > 0 or count_all > 0

# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: Dict[str, Type['CustomOp']] = {}

# Decorator to register custom ops.
@classmethod
def register(cls, name: str):

def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}"
op_cls.name = name
cls.op_registry[name] = op_cls
return op_cls

return decorator
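Taken together: register() attaches a name to the class and records it in op_registry, enabled() combines the global default with any per-op '+'/'-' override, and default_on() resolves the global default, where 'all' forces it on, 'none' forces it off, and otherwise it follows VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR. A minimal usage sketch with a hypothetical op name:

from vllm.model_executor.custom_op import CustomOp

@CustomOp.register("my_op")  # 'my_op' is a hypothetical name, not a real vLLM op
class MyOp(CustomOp):

    def forward_native(self, x):
        return x

    def forward_cuda(self, x):
        return x

assert CustomOp.op_registry["my_op"] is MyOp
print(MyOp.enabled())  # honors VLLM_CUSTOM_OPS and the compile level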
36 changes: 25 additions & 11 deletions vllm/model_executor/layers/activation.py
@@ -11,11 +11,13 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.

The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
@@ -40,6 +42,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.

@@ -74,6 +77,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return out


@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.

@@ -123,6 +127,7 @@ def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'


@CustomOp.register("gelu_new")
class NewGELU(CustomOp):

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -144,6 +149,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return ops.gelu_new(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -164,8 +170,8 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return ops.gelu_fast(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):

# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
@@ -189,6 +195,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
@@ -244,15 +251,22 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data.copy_(loaded_weight)


_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"relu2": ReLUSquaredActivation(),
"quick_gelu": QuickGELU(),
}
_ACTIVATION_REGISTRY = LazyDict({
"gelu":
lambda: nn.GELU(),
"gelu_fast":
lambda: FastGELU(),
"gelu_new":
lambda: NewGELU(),
"gelu_pytorch_tanh":
lambda: nn.GELU(approximate="tanh"),
"relu":
lambda: nn.ReLU(),
"relu2":
lambda: ReLUSquaredActivation(),
"quick_gelu":
lambda: QuickGELU(),
})


def get_act_fn(
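The LazyDict used for _ACTIVATION_REGISTRY lives in vllm.utils and its implementation is not part of this excerpt; the intent is to defer constructing each activation module (and therefore its enabled()/dispatch decision) until the entry is first looked up. A rough sketch of such a mapping, under the assumption that values are built and cached on first access:

from collections.abc import Mapping
from typing import Any, Callable, Dict


class LazyDict(Mapping):
    # Maps names to zero-argument factories; each value is constructed
    # lazily on first access and cached for later lookups.

    def __init__(self, factories: Dict[str, Callable[[], Any]]):
        self._factories = factories
        self._cache: Dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        if key not in self._cache:
            self._cache[key] = self._factories[key]()
        return self._cache[key]

    def __iter__(self):
        return iter(self._factories)

    def __len__(self) -> int:
        return len(self._factories)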
5 changes: 1 addition & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -37,13 +37,13 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor,
raise NotImplementedError


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):

# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
@@ -74,7 +74,6 @@ def apply(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:

return self.forward(x=x,
layer=layer,
router_logits=router_logits,
@@ -97,7 +96,6 @@ def forward_cuda(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:

from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)

@@ -134,7 +132,6 @@ def forward_tpu(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:

from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -7,6 +7,7 @@
from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.

@@ -122,6 +123,7 @@ def extra_repr(self) -> str:
return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.

3 changes: 2 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -72,6 +72,7 @@ def _apply_rotary_emb(
return torch.stack((o1, o2), dim=-1).flatten(-2)


@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""

@@ -468,7 +469,7 @@ def __init__(
self.long_factor = long_factor

scale = self.max_position_embeddings / \
self.original_max_position_embeddings
self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else: