10 changes: 5 additions & 5 deletions vllm/model_executor/layers/quantization/awq.py
@@ -30,7 +30,7 @@ def __init__(
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.modules_to_not_convert = modules_to_not_convert or []
self.ignored_modules = modules_to_not_convert or []

if self.weight_bits != 4:
raise ValueError(
@@ -42,7 +42,7 @@ def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")
f"modules_to_not_convert={self.ignored_modules})")

def get_name(self) -> str:
return "awq"
@@ -75,14 +75,14 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
if is_layer_skipped_awq(prefix, self.ignored_modules):
return UnquantizedLinearMethod()
return AWQLinearMethod(self)
return None


def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
def is_layer_skipped_awq(prefix: str, ignored_modules: List[str]):
return any(module_name in prefix for module_name in ignored_modules)


class AWQLinearMethod(LinearMethodBase):
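The renamed helper is a plain substring match against the layer's dotted prefix. A minimal standalone sketch of that behavior (the ignored entry and the layer prefixes below are illustrative, not taken from this PR):

from typing import List

def is_layer_skipped_awq(prefix: str, ignored_modules: List[str]) -> bool:
    # a layer is left unquantized if any ignored name appears anywhere in its prefix
    return any(module_name in prefix for module_name in ignored_modules)

ignored = ["visual"]
print(is_layer_skipped_awq("visual.blocks.0.attn.qkv_proj", ignored))      # True  -> UnquantizedLinearMethod
print(is_layer_skipped_awq("model.layers.0.self_attn.qkv_proj", ignored))  # False -> AWQLinearMethod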
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/base_config.py
@@ -62,8 +62,9 @@ class QuantizationConfig(ABC):

def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
# These attributes are updated by models as they initialize
self.packed_modules_mapping: Dict[str, List[str]] = dict()
self.ignored_modules: List[str] = list()

@abstractmethod
def get_name(self) -> str:
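Both containers now live on the base QuantizationConfig, so every backend config exposes them whether or not it sets them itself. A toy mimic of that pattern (ToyQuantConfig and the values below are illustrative, not the real vllm class):

from typing import Dict, List

class ToyQuantConfig:
    def __init__(self) -> None:
        # per-instance containers, filled in by model classes as they initialize
        self.packed_modules_mapping: Dict[str, List[str]] = dict()
        self.ignored_modules: List[str] = list()

cfg = ToyQuantConfig()
cfg.packed_modules_mapping.update({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
cfg.ignored_modules = ["visual"]  # e.g. keep a vision tower unquantized
print(cfg.packed_modules_mapping, cfg.ignored_modules)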
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -39,7 +39,7 @@ def __init__(
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.ignored_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold

if self.bnb_4bit_quant_storage not in ["uint8"]:
Expand All @@ -52,7 +52,7 @@ def __repr__(self) -> str:
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
f"llm_int8_skip_modules={self.ignored_modules})")

@classmethod
def get_name(self) -> str:
@@ -122,25 +122,25 @@ def get_safe_value(config, keys, default_value=None):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
if is_layer_skipped_bnb(prefix, self.ignored_modules):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
return None


def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
def is_layer_skipped_bnb(prefix: str, ignored_modules: List[str]):
# Split the prefix into its dot-separated components
components = prefix.split('.')

# Check if any of the skip modules exactly matches any component
substr_check = any(module_name in components
for module_name in llm_int8_skip_modules)
for module_name in ignored_modules)

# Allow certain layers to not be quantized
set_components = set(".".join(components[:i + 1])
for i in range(len(components)))
set_llm_int8_skip_modules = set(llm_int8_skip_modules)
prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
set_ignored_modules = set(ignored_modules)
prefix_check = len(set_ignored_modules & set_components) != 0

return substr_check or prefix_check

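The bitsandbytes check accepts either an exact dot-separated component match or a dotted-prefix match. A standalone copy of the logic above, exercised with illustrative module names:

from typing import List

def is_layer_skipped_bnb(prefix: str, ignored_modules: List[str]) -> bool:
    components = prefix.split('.')
    # exact match against any single component of the prefix
    substr_check = any(module_name in components for module_name in ignored_modules)
    # match against any leading dotted prefix ("model", "model.layers", ...)
    set_components = set(".".join(components[:i + 1]) for i in range(len(components)))
    prefix_check = len(set(ignored_modules) & set_components) != 0
    return substr_check or prefix_check

ignored = ["lm_head", "model.layers.0"]
print(is_layer_skipped_bnb("lm_head", ignored))                       # True  (component match)
print(is_layer_skipped_bnb("model.layers.0.mlp.down_proj", ignored))  # True  (prefix match)
print(is_layer_skipped_bnb("model.layers.1.mlp.down_proj", ignored))  # False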
47 changes: 37 additions & 10 deletions vllm/model_executor/models/interfaces.py
@@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import inspect
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal,
Optional, Protocol, Type, Union, overload,
runtime_checkable)

import torch
from torch import Tensor
@@ -452,28 +454,53 @@ class SupportsQuant:
quant_config: Optional[QuantizationConfig] = None

def __new__(cls, *args, **kwargs) -> Self:
from .utils import WeightsMapper # avoid circular import

instance = super().__new__(cls)
quant_config = cls._find_quant_config(*args, **kwargs)
bound_args = inspect.signature(cls.__init__).bind(
instance, *args, **kwargs).arguments

quant_config = cls._find_quant_config(bound_args)
prefix = cls._find_prefix(bound_args)
packed_modules_mapping = cls.packed_modules_mapping
hf_to_vllm_mapper: WeightsMapper = getattr(cls, "hf_to_vllm_mapper",
WeightsMapper())

if quant_config is not None:
# 1. update qconfig's packed_modules_mapping
# currently takes union, in the future could be more precise
# using prefix and hf_to_vllm_mapper
quant_config.packed_modules_mapping.update(packed_modules_mapping)

# 2. update qconfig's ignored modules
quant_config.ignored_modules = [
prefix + hf_to_vllm_mapper._map_name(module[len(prefix):])
if module.startswith(prefix) else module
for module in quant_config.ignored_modules
]

# 3. set module's quantization config
instance.quant_config = quant_config
instance.quant_config.packed_modules_mapping.update(
cls.packed_modules_mapping)

return instance

@staticmethod
def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
def _find_quant_config(
bound_args: Dict[str, Any]) -> Optional[QuantizationConfig]:
from vllm.config import VllmConfig # avoid circular import

args_values = list(args) + list(kwargs.values())
for arg in args_values:
for arg in bound_args.values():
if isinstance(arg, VllmConfig):
return arg.quant_config

if isinstance(arg, QuantizationConfig):
elif isinstance(arg, QuantizationConfig):
return arg

return None

@staticmethod
def _find_prefix(bound_args: Dict[str, Any]) -> str:
return bound_args.get("prefix", "")


@runtime_checkable
class SupportsTranscription(Protocol):
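The new __new__ logic normalizes positional and keyword constructor arguments into a single name-to-value dict via inspect, pulls the prefix from it, and rewrites HF-style ignored-module names into vLLM names under that prefix. A standalone sketch of those two steps (ToyModel, the rename rule, and the module names are illustrative stand-ins, not vllm internals):

import inspect
from typing import Any, Dict

class ToyModel:
    def __init__(self, vllm_config: Any = None, prefix: str = ""):
        pass

instance = object.__new__(ToyModel)
bound_args: Dict[str, Any] = inspect.signature(ToyModel.__init__).bind(
    instance, prefix="visual.").arguments
prefix = bound_args.get("prefix", "")  # what _find_prefix returns

def map_name(name: str) -> str:  # stand-in for WeightsMapper._map_name
    return name.replace("attn.qkv", "attn.qkv_proj")

ignored_modules = ["visual.blocks.0.attn.qkv", "lm_head"]
remapped = [
    prefix + map_name(m[len(prefix):]) if m.startswith(prefix) else m
    for m in ignored_modules
]
print(remapped)  # ['visual.blocks.0.attn.qkv_proj', 'lm_head']

Note that parameters left at their defaults do not appear in bound_args, which is why _find_prefix falls back to an empty string.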
5 changes: 3 additions & 2 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -60,7 +60,7 @@
from vllm.transformers_utils.config import uses_mrope

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
SupportsMultiModal, SupportsPP, SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision)
@@ -764,7 +764,8 @@ def _get_mm_fields_config(
info=Qwen2_5_VLProcessingInfo,
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
SupportsLoRA, SupportsPP,
SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",