diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py
new file mode 100644
index 000000000000..42eff7f8b814
--- /dev/null
+++ b/tests/models/language/pooling/test_st_projector.py
@@ -0,0 +1,201 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import numpy as np
+import pytest
+from scipy.spatial.distance import cosine
+
+from ...utils import EmbedModelInfo
+from .mteb_utils import MTEB_EMBED_TOL, mteb_test_embed_models
+
+
+def _get_vllm_embeddings(vllm_runner, model_info: EmbedModelInfo,
+                         test_texts: list[str]):
+    """Get embeddings from vLLM."""
+    vllm_extra_kwargs: dict[str, Any] = {}
+    if model_info.architecture == "GteNewModel":
+        vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
+
+    with vllm_runner(
+            model_info.name,
+            runner="pooling",
+            max_model_len=None,
+            trust_remote_code=True,
+            **vllm_extra_kwargs,
+    ) as vllm_model:
+        embeddings = vllm_model.encode(test_texts)
+
+        # Extract tensor/numpy data
+        data = []
+        for emb in embeddings:
+            if hasattr(emb, "outputs"):
+                data.append(emb.outputs.data.cpu().numpy())
+            else:
+                data.append(emb.cpu().numpy() if hasattr(emb, "cpu") else emb)
+        return np.array(data)
+
+
+def _get_hf_embeddings(hf_runner, model_info: EmbedModelInfo,
+                       test_texts: list[str]):
+    """Get embeddings from HuggingFace ST interface."""
+    with hf_runner(
+            model_info.name,
+            is_sentence_transformer=True,
+            dtype="float32",
+    ) as hf_model:
+        embeddings = hf_model.encode(test_texts)
+        if hasattr(embeddings, "cpu"):
+            return embeddings.cpu().numpy()
+        return np.array(embeddings)
+
+
+# ST models with projector (Dense) layers
+ST_PROJECTOR_MODELS = [
+    EmbedModelInfo(
+        "TencentBAC/Conan-embedding-v1",
+        architecture="BertModel",
+        enable_test=True,
+    ),
+]
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
+    """MTEB test for ST projector models to detect numerical issues."""
+    vllm_extra_kwargs: dict[str, Any] = {}
+    if model_info.architecture == "BertModel":
+        # Ensure BertEmbeddingModel is used for embedding models
+        vllm_extra_kwargs["trust_remote_code"] = True
+
+    mteb_test_embed_models(hf_runner,
+                           vllm_runner,
+                           model_info,
+                           vllm_extra_kwargs,
+                           atol=MTEB_EMBED_TOL)
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_st_projector_loading(vllm_runner, model_info: EmbedModelInfo) -> None:
+    """Ensure projector models load and output expected dim."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = ["This is a test sentence."]
+    embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+
+    actual_dim = embeddings_data.shape[-1]
+    expected_dim = 1792
+    assert actual_dim == expected_dim, (
+        f"Expected {expected_dim}, got {actual_dim}")
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_compare_with_hf_dimensions(hf_runner, vllm_runner,
+                                    model_info: EmbedModelInfo) -> None:
+    """Compare embedding dimensions between vLLM and HuggingFace."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = ["This is a test sentence for dimension comparison."]
+
+    vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+    hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)
+
+    vllm_dim = vllm_data.shape[-1]
+    hf_dim = hf_data.shape[-1]
+
+    assert vllm_dim == hf_dim, ("Embedding dim mismatch: "
+                                f"vLLM {vllm_dim} vs HF {hf_dim}")
+    print(f"✓ Embedding dimensions match: {vllm_dim}")
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_embedding_numerical_similarity(hf_runner, vllm_runner,
+                                        model_info: EmbedModelInfo) -> None:
+    """Numerical similarity between vLLM and HF embeddings."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = [
+        "This is a test sentence for numerical comparison.",
+        "Another sentence to verify embedding quality.",
+        "机器学习是人工智能的一个重要分支。",  # Chinese test
+    ]
+
+    vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+    hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)
+
+    assert vllm_data.shape == hf_data.shape, (
+        "Shape mismatch: "
+        f"vLLM {vllm_data.shape} vs HF {hf_data.shape}")
+
+    print(f"Embedding shape: {vllm_data.shape}")
+    print(f"Embedding dimension: {vllm_data.shape[-1]}")
+
+    similarities = []
+    for i, text in enumerate(test_texts):
+        vllm_emb = vllm_data[i]
+        hf_emb = hf_data[i]
+
+        similarity = 1 - cosine(vllm_emb, hf_emb)
+        similarities.append(similarity)
+
+        preview = text[:50] + ("..." if len(text) > 50 else "")
+        print(f"Text {i + 1}: '{preview}'")
+        print(f"  Cosine similarity: {similarity:.6f}")
+
+        min_similarity = 0.95
+        assert similarity > min_similarity, (
+            f"Text {i + 1} similarity too low: "
+            f"{similarity:.6f} < {min_similarity}\n"
+            f"vLLM norm: {np.linalg.norm(vllm_emb):.6f}, "
+            f"HF norm: {np.linalg.norm(hf_emb):.6f}")
+
+    avg_similarity = np.mean(similarities)
+    print(f"\nAverage cosine similarity: {avg_similarity:.6f}")
+
+    assert avg_similarity > 0.98, (
+        f"Average similarity too low: {avg_similarity:.6f} < 0.98")
+    print("✓ All numerical similarity tests passed!")
+
+
+@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
+def test_embedding_quality_checks(vllm_runner,
+                                  model_info: EmbedModelInfo) -> None:
+    """Basic quality checks: non-zero, non-constant, distinct."""
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    test_texts = [
+        "First test sentence.",
+        "Second different sentence.",
+        "Completely different content here.",
+    ]
+
+    embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
+
+    print(f"Embeddings shape: {embeddings_data.shape}")
+
+    # Non-zero and non-constant
+    for i, emb in enumerate(embeddings_data):
+        norm = np.linalg.norm(emb)
+        print(f"Embedding {i + 1} L2 norm: {norm:.6f}")
+        assert norm > 1e-6, (
+            f"Embedding {i + 1} too close to zero: norm={norm}")
+
+        std = np.std(emb)
+        print(f"Embedding {i + 1} std: {std:.6f}")
+        assert std > 1e-6, (
+            f"Embedding {i + 1} too close to constant: std={std}")
+
+    # Different texts should differ
+    for i in range(len(embeddings_data)):
+        for j in range(i + 1, len(embeddings_data)):
+            sim = 1 - cosine(embeddings_data[i], embeddings_data[j])
+            print(f"Similarity between text {i + 1} and {j + 1}: {sim:.6f}")
+            assert sim < 0.99, ("Embeddings too similar: "
+                                f"{i + 1} vs {j + 1} -> {sim:.6f}")
+
+    print("✓ All embedding quality checks passed!")
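For readers who want to reproduce the comparison outside the test harness, here is a minimal standalone sketch. It is not part of the patch and assumes a recent vLLM where `LLM` accepts `runner="pooling"` and exposes `embed()`, plus an installed `sentence-transformers`:

```python
# Standalone sketch (editor's illustration, not part of this patch).
import numpy as np
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
from vllm import LLM

texts = ["This is a test sentence."]

# Assumes LLM(..., runner="pooling") and llm.embed() are available.
llm = LLM(model="TencentBAC/Conan-embedding-v1", runner="pooling",
          trust_remote_code=True)
vllm_emb = np.array(llm.embed(texts)[0].outputs.embedding)

st_model = SentenceTransformer("TencentBAC/Conan-embedding-v1")
hf_emb = st_model.encode(texts)[0]

print(vllm_emb.shape, hf_emb.shape)  # both should end in 1792
print(1 - cosine(vllm_emb, hf_emb))  # expect > 0.95 per the tests above
```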
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 7ce44174ead6..c0891171dcfb 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -392,12 +392,27 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     lambda: nn.SiLU(),
     "quick_gelu":
     lambda: QuickGELU(),
+    "tanh":
+    lambda: nn.Tanh(),
+    "sigmoid":
+    lambda: nn.Sigmoid(),
+    "swish":
+    lambda: nn.SiLU(),
 })
 
 
 def get_act_fn(act_fn_name: str) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
+
+    # Sentence-Transformers Dense configs store the full class path
+    # (e.g. "torch.nn.modules.activation.Tanh"); strip it to the final
+    # component so it can be resolved against the registry.
+    if act_fn_name.startswith("torch.nn.modules."):
+        activation_name = act_fn_name.split(".")[-1]
+        if activation_name == "identity":
+            return nn.Identity()
+        act_fn_name = activation_name
+
     if act_fn_name not in _ACTIVATION_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")
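The normalization above matters because Sentence-Transformers stores activations as full class paths. A rough sketch of the behavior this hunk implements, under the assumption that the registry names match the diff:

```python
# Illustrative only: how get_act_fn resolves ST-style activation names.
import torch.nn as nn
from vllm.model_executor.layers.activation import get_act_fn

assert isinstance(get_act_fn("tanh"), nn.Tanh)
# Full class paths are lowercased and stripped to the last component,
# so this also resolves to nn.Tanh:
assert isinstance(get_act_fn("torch.nn.modules.activation.Tanh"), nn.Tanh)
# Identity is special-cased because it is not in the registry:
assert isinstance(get_act_fn("torch.nn.modules.linear.Identity"), nn.Identity)
```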
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index e2162e5cbf95..a19b9d5cdc1c 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from enum import IntEnum
 from itertools import groupby
-from typing import Callable, Optional, TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union, cast
 
 import torch
 import torch.nn as nn
@@ -77,13 +77,17 @@ def for_encode(pooler_config: PoolerConfig):
         return SimplePooler.from_config(resolved_config)
 
     @staticmethod
-    def for_embed(pooler_config: PoolerConfig):
+    def for_embed(
+        pooler_config: PoolerConfig,
+        *,
+        projector: Optional[nn.Module] = None,
+    ):
         resolved_config = ResolvedPoolingConfig.from_config(
             task="embed",
             pooler_config=pooler_config,
         )
 
-        return SimplePooler.from_config(resolved_config)
+        return SimplePooler.from_config(resolved_config, projector=projector)
 
     @staticmethod
     def for_classify(
@@ -454,12 +458,77 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
 
 class EmbeddingPoolerHead(PoolerHead):
 
-    def __init__(self) -> None:
+    def __init__(self, projector: Optional[nn.Module] = None) -> None:
         super().__init__(activation=PoolerNormalize())
+        self.projector = projector
+        self._projector_dim_checked = False
+
+    def _sync_projector_to_ref(self, ref_tensor: torch.Tensor) -> None:
+        """Ensure the projector is on the correct device with float32 dtype."""
+        if self.projector is None:
+            return
+
+        projector = cast(nn.Module, self.projector)
+        try:
+            proj_device = next(projector.parameters()).device
+            if proj_device != ref_tensor.device:
+                projector.to(device=ref_tensor.device, dtype=torch.float32)
+                # Ensure all parameters are float32
+                for param in projector.parameters():
+                    param.data = param.data.to(torch.float32)
+        except StopIteration:
+            # Empty projector, skip device check
+            pass
+
+    def _validate_projector_dimensions(self, ref_tensor: torch.Tensor) -> None:
+        """Validate that the projector input dim matches the pooled output."""
+        if self.projector is None:
+            return
+
+        projector = cast(nn.Module, self.projector)
+        first_linear = None
+        for module in projector.modules():
+            if isinstance(module, nn.Linear):
+                first_linear = module
+                break
+
+        if first_linear is not None:
+            expected_dim = first_linear.in_features
+            actual_dim = ref_tensor.shape[-1]
+            if expected_dim != actual_dim:
+                raise ValueError(
+                    f"Dimension mismatch: Dense projector expects "
+                    f"input dim {expected_dim}, but pooled output "
+                    f"has dim {actual_dim}")
 
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
 
+        # Apply ST projector
+        if self.projector is not None:
+            if isinstance(pooled_data, list) and len(pooled_data) == 0:
+                pass  # Skip projection for empty inputs
+            else:
+                projector = cast(nn.Module, self.projector)
+                ref = pooled_data[0] if isinstance(pooled_data,
+                                                   list) else pooled_data
+
+                self._sync_projector_to_ref(ref)
+
+                if not self._projector_dim_checked:
+                    self._validate_projector_dimensions(ref)
+                    self._projector_dim_checked = True
+
+                def _proj(x: torch.Tensor) -> torch.Tensor:
+                    orig_dtype = x.dtype
+                    y = projector(x.to(torch.float32))
+                    return y.to(orig_dtype)
+
+                if isinstance(pooled_data, torch.Tensor):
+                    pooled_data = _proj(pooled_data)
+                else:
+                    pooled_data = [_proj(t) for t in pooled_data]
+
         pooling_params = get_pooling_params(pooling_metadata)
 
         # for matryoshka representation
@@ -530,12 +599,13 @@ class SimplePooler(Pooler):
     def from_config(
         cls,
         pooler_config: ResolvedPoolingConfig,
+        projector: Optional[nn.Module] = None,
     ) -> "SimplePooler":
         pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
         if pooler_config.task == "embed":
-            head = EmbeddingPoolerHead()
+            head = EmbeddingPoolerHead(projector=projector)
         elif pooler_config.task == "encode":
             head = RewardPoolerHead()
         else:
             raise NotImplementedError(f"Unknown task: {pooler_config.task}")
         return cls(pooling, head)
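The head applies the projector after pooling and before L2 normalization, always running the matmul in float32 and casting back. A minimal sketch of that order of operations (shapes and dtypes are illustrative, not taken from the patch):

```python
# Illustrative sketch: pool -> project (in fp32) -> L2-normalize.
import torch
import torch.nn as nn
import torch.nn.functional as F

# A Dense projector as _load_st_projector would build it (dims assumed).
projector = nn.Sequential(nn.Linear(768, 1792), nn.Tanh()).float()

pooled = torch.randn(4, 768, dtype=torch.float16)        # pooled hidden states
projected = projector(pooled.float()).to(pooled.dtype)   # fp32 matmul, cast back
embeddings = F.normalize(projected.float(), p=2, dim=-1) # PoolerNormalize step

print(embeddings.shape)  # torch.Size([4, 1792])
```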
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 1dbe70f84a62..371d0e5f7f52 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -7,6 +7,9 @@
 import torch
 import torch.nn as nn
 
+from vllm.config import ModelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.models.config import VerifyAndUpdateConfig
 
 from .interfaces_base import VllmModelForPooling, is_pooling_model
@@ -16,6 +19,8 @@
 
 _T = TypeVar("_T", bound=type[nn.Module])
 
+logger = init_logger(__name__)
+
 _GENERATE_SUFFIXES = [
     "ForCausalLM",
     "ForConditionalGeneration",
@@ -24,6 +29,129 @@
 ]
 
 
+def _load_weights_to_linear(state_dict: dict, linear: nn.Linear) -> bool:
+    """Load weights from a state dict into a linear layer."""
+    weight = None
+    bias = None
+
+    for weight_key in ["linear.weight", "dense.weight", "weight"]:
+        if weight_key in state_dict:
+            weight = state_dict[weight_key]
+            break
+
+    for bias_key in ["linear.bias", "dense.bias", "bias"]:
+        if bias_key in state_dict:
+            bias = state_dict[bias_key]
+            break
+
+    if weight is None:
+        return False
+
+    try:
+        with torch.no_grad():
+            # Ensure weights are float32 for numerical stability
+            linear.weight.copy_(weight.to(torch.float32))
+            if linear.bias is not None and bias is not None:
+                linear.bias.copy_(bias.to(torch.float32))
+        return True
+    except RuntimeError as e:
+        logger.warning("Failed to load weights into linear layer: %s", e)
+        return False
+
+
+def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
+    """Load Sentence-Transformers Dense projection layers."""
+    from vllm.transformers_utils.config import (get_hf_file_bytes,
+                                                get_hf_file_to_dict)
+
+    model_path = model_config.model
+    revision = model_config.revision
+
+    # Read modules.json
+    modules = get_hf_file_to_dict("modules.json", model_path, revision)
+
+    # Handle dict format (some ST variants)
+    if isinstance(modules, dict):
+        modules = modules.get("modules", [])
+    if not isinstance(modules, list):
+        return None
+
+    # Filter Dense modules
+    dense_entries = [
+        m for m in modules
+        if m.get("type") == "sentence_transformers.models.Dense"
+    ]
+    if not dense_entries:
+        return None
+
+    # Build projection layer sequence
+    layers = []
+    for entry in dense_entries:
+        folder = entry.get("path")
+        if not folder:
+            continue
+
+        # Read config
+        cfg = get_hf_file_to_dict(f"{folder}/config.json", model_path,
+                                  revision)
+        if not cfg:
+            continue
+
+        in_features = cfg.get("in_features")
+        out_features = cfg.get("out_features")
+        if in_features is None or out_features is None:
+            continue
+
+        use_bias = cfg.get("bias", True)
+        # Create linear layer with float32 for numerical stability
+        linear = nn.Linear(in_features, out_features, bias=use_bias)
+
+        # Try to load weights: first safetensors, then pytorch_model.bin
+        weight_loaded = False
+
+        # Try safetensors
+        try:
+            b = get_hf_file_bytes(f"{folder}/model.safetensors", model_path,
+                                  revision)
+            if b is not None:
+                from safetensors.torch import load as st_load
+                sd = st_load(b)
+                weight_loaded = _load_weights_to_linear(sd, linear)
+        except (OSError, ImportError, ValueError) as e:
+            logger.debug("Failed to load safetensors from %s: %s", folder, e)
+
+        if not weight_loaded:
+            import io
+            import pickle
+            try:
+                b = get_hf_file_bytes(f"{folder}/pytorch_model.bin",
+                                      model_path, revision)
+                if b is not None:
+                    sd = torch.load(io.BytesIO(b), map_location="cpu")
+                    weight_loaded = _load_weights_to_linear(sd, linear)
+            except (OSError, pickle.UnpicklingError, RuntimeError,
+                    ValueError) as e:
+                logger.debug("Failed to load pytorch_model.bin from %s: %s",
+                             folder, e)
+
+        if not weight_loaded:
+            logger.warning("Failed to load weights for Dense layer in %s",
+                           folder)
+
+        layers.append(linear)
+        activation_name = cfg.get("activation_function")
+        if activation_name is not None:
+            layers.append(get_act_fn(activation_name))
+
+    if not layers:
+        return None
+
+    # Ensure the entire module uses float32
+    projector = nn.Sequential(*layers)
+    projector = projector.to(dtype=torch.float32)
+    return projector
+
+
 def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
     model_name = orig_model_name
 
@@ -123,11 +252,15 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
             pooler_config = vllm_config.model_config.pooler_config
             assert pooler_config is not None
 
-            self.pooler = DispatchPooler(
-                {
-                    "encode": Pooler.for_encode(pooler_config),
-                    "embed": Pooler.for_embed(pooler_config),
-                }, )
+            # Load ST projector for the embed task only
+            projector = _load_st_projector(vllm_config.model_config)
+
+            self.pooler = DispatchPooler({
+                "encode":
+                Pooler.for_encode(pooler_config),
+                "embed":
+                Pooler.for_embed(pooler_config, projector=projector),
+            })
 
     ModelForEmbedding.__name__ = \
         _get_pooling_model_name(cls.__name__, "ForEmbedding")
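For reference, `_load_st_projector` expects the standard Sentence-Transformers repository layout. An illustrative (not exhaustive) sketch of the files it reads; folder names and dimensions here are conventional examples, not taken from the patch:

```python
# Illustrative repo layout consumed by _load_st_projector.
modules_json = [  # modules.json at the repo root
    {"idx": 0, "name": "0", "path": "",
     "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling",
     "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "3_Dense",
     "type": "sentence_transformers.models.Dense"},
]
dense_config_json = {  # 3_Dense/config.json (example values)
    "in_features": 1024,
    "out_features": 1792,
    "bias": True,
    "activation_function": "torch.nn.modules.activation.Tanh",
}
# Weights live in 3_Dense/model.safetensors or 3_Dense/pytorch_model.bin,
# under a key such as "linear.weight", matching the fallbacks in
# _load_weights_to_linear above.
```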
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 6638f06f9826..32d3038a36fb 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -28,6 +28,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
 
+from .adapters import _load_st_projector
 from .interfaces import (SupportsCrossEncoding, SupportsQuant,
                          default_pooling_type)
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -456,7 +457,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = self._build_model(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
 
-        self.pooler = self._build_pooler(pooler_config)
+        self.pooler = self._build_pooler(pooler_config, vllm_config)
 
     def forward(
         self,
@@ -488,10 +489,15 @@ def _build_model(self,
                          prefix=prefix,
                          embedding_class=BertEmbedding)
 
-    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+    def _build_pooler(self, pooler_config: PoolerConfig,
+                      vllm_config: VllmConfig) -> Pooler:
+        projector = _load_st_projector(vllm_config.model_config)
+
         return DispatchPooler({
-            "encode": Pooler.for_encode(pooler_config),
-            "embed": Pooler.for_embed(pooler_config),
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "embed":
+            Pooler.for_embed(pooler_config, projector=projector),
         })
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index d8c964fb2a4a..ecaa6a5f9fd2 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -908,3 +908,33 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
                        exc_info=e)
 
     return max_position_embeddings
+
+
+def get_hf_file_bytes(file_name: str,
+                      model: Union[str, Path],
+                      revision: Optional[str] = 'main') -> Optional[bytes]:
+    """Get file contents from a HuggingFace repository as bytes."""
+    file_path = try_get_local_file(model=model,
+                                   file_name=file_name,
+                                   revision=revision)
+
+    if file_path is None:
+        try:
+            hf_hub_file = hf_hub_download(model,
+                                          file_name,
+                                          revision=revision,
+                                          token=_get_hf_token())
+            file_path = Path(hf_hub_file)
+        except (OSError, ValueError) as e:
+            logger.debug("Failed to download %s from HF: %s", file_name, e)
+            return None
+
+    if file_path is not None and file_path.is_file():
+        try:
+            with open(file_path, 'rb') as file:
+                return file.read()
+        except OSError as e:
+            logger.debug("Failed to read file %s: %s", file_path, e)
+            return None
+
+    return None
\ No newline at end of file
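A hedged usage sketch for the new helper, mirroring how the adapter loads Dense weights; the file path assumes the conventional ST layout discussed earlier:

```python
# Illustrative only: fetch a Dense layer's weights via get_hf_file_bytes.
from safetensors.torch import load as st_load
from vllm.transformers_utils.config import get_hf_file_bytes

raw = get_hf_file_bytes("3_Dense/model.safetensors",
                        "TencentBAC/Conan-embedding-v1")
if raw is not None:
    state_dict = st_load(raw)  # e.g. {"linear.weight": tensor(...), ...}
    print({k: tuple(v.shape) for k, v in state_dict.items()})
```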