diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 77b2ef0fce5f..111b49ab8dd2 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -26,7 +26,7 @@ from vllm.transformers_utils.config import (
     get_cross_encoder_activation_function)
 
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix
 
@@ -313,7 +313,8 @@ def forward(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
-class BertModel(nn.Module):
+class BertModel(nn.Module, SupportsQuant):
+    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
 
     def __init__(self,
                  *,
@@ -385,7 +386,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class BertEmbeddingModel(nn.Module, SupportsV0Only):
+class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
@@ -443,7 +444,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
                       softmax=False)
 
 
-class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                    SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
 
     This class encapsulates the BertModel and provides an interface for
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index bedbdceb7721..f3d488926d09 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from .interfaces import SupportsQuant
+
 
 def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
@@ -243,9 +245,10 @@ def forward(self, inputs_embeds: torch.Tensor):
         return hidden_states
 
 
-class BlipVisionModel(nn.Module):
+class BlipVisionModel(nn.Module, SupportsQuant):
     config_class = BlipVisionConfig
     main_input_name = "pixel_values"
+    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
 
     def __init__(
         self,
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 7adca4f0dc86..db9d42f5b86a 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -24,7 +24,8 @@ from vllm.sequence import IntermediateTensors
 
 from .blip import BlipVisionModel
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
+                         SupportsQuant)
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -498,7 +499,8 @@ def _get_prompt_updates(
 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
                                         info=Blip2ProcessingInfo,
                                         dummy_inputs=Blip2DummyInputsBuilder)
-class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
+                                    SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 50f48f91798a..f960075b98bc 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP, SupportsV0Only
+from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -279,7 +279,7 @@ def forward(
         return hidden_states
 
 
-class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only):
+class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
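
For readers unfamiliar with the interface this patch applies, the sketch below approximates what SupportsQuant and packed_modules_mapping do. It is a simplified stand-alone rendering, not vLLM's actual implementation: the QuantizationConfig stub, the keyword-argument lookup in __new__, and the BertModelSketch class are assumptions for illustration only. The idea it demonstrates is real, though: the mixin exposes a class-level mapping from fused module names (e.g. "qkv_proj") back to the per-tensor names in the original checkpoint ("query", "key", "value"), and merges that mapping into the quantization config so schemes written against checkpoint names still match the fused layer.

# Simplified sketch of the SupportsQuant pattern (assumed names; not vLLM's code).
from typing import ClassVar, Optional


class QuantizationConfig:
    """Stand-in stub for vLLM's QuantizationConfig."""

    def __init__(self) -> None:
        self.packed_modules_mapping: dict[str, list[str]] = {}


class SupportsQuant:
    """Mixin: a model declares how its fused modules map to checkpoint names."""

    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        # Hypothetical lookup: assume the config arrives as a keyword argument.
        quant_config = kwargs.get("quant_config")
        if isinstance(quant_config, QuantizationConfig):
            # Merge the model's fused-module mapping into the shared config so
            # quantization schemes targeting "query"/"key"/"value" also apply
            # to the fused "qkv_proj" layer.
            quant_config.packed_modules_mapping.update(
                cls.packed_modules_mapping)
            instance.quant_config = quant_config
        return instance


class BertModelSketch(SupportsQuant):
    """Hypothetical model mirroring the mapping added to BertModel above."""

    packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}

    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        pass  # quant_config was already captured and merged in __new__


cfg = QuantizationConfig()
BertModelSketch(quant_config=cfg)
assert cfg.packed_modules_mapping == {"qkv_proj": ["query", "key", "value"]}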