[Bugfix] Fix Dense module loading for sentence-transformers embedding models (simplified version) #23019
base: main
Changes from all commits: fdba6a9, d2c5380, d4f9655, 162fed3, 48a384c, 75d5d95, b662dd0, 1f67604, a3c6de8
@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import numpy as np
import pytest
from scipy.spatial.distance import cosine

from ...utils import EmbedModelInfo
from .mteb_utils import MTEB_EMBED_TOL, mteb_test_embed_models


def _get_vllm_embeddings(vllm_runner, model_info: EmbedModelInfo,
                         test_texts: list[str]):
    """Get embeddings from vLLM."""
    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "GteNewModel":
        vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}

    with vllm_runner(
            model_info.name,
            runner="pooling",
            max_model_len=None,
            trust_remote_code=True,
            **vllm_extra_kwargs,
    ) as vllm_model:
        embeddings = vllm_model.encode(test_texts)

    # Extract tensor/numpy data
    data = []
    for emb in embeddings:
        if hasattr(emb, "outputs"):
            data.append(emb.outputs.data.cpu().numpy())
        else:
            data.append(emb.cpu().numpy() if hasattr(emb, "cpu") else emb)
    return np.array(data)


def _get_hf_embeddings(hf_runner, model_info: EmbedModelInfo,
                       test_texts: list[str]):
    """Get embeddings from the HuggingFace ST interface."""
    with hf_runner(
            model_info.name,
            is_sentence_transformer=True,
            dtype="float32",
    ) as hf_model:
        embeddings = hf_model.encode(test_texts)
        if hasattr(embeddings, "cpu"):
            return embeddings.cpu().numpy()
        return np.array(embeddings)


# ST models with projector (Dense) layers
ST_PROJECTOR_MODELS = [
    EmbedModelInfo(
        "TencentBAC/Conan-embedding-v1",
        architecture="BertModel",
        enable_test=True,
    ),
]


@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:
    """MTEB test for ST projector models to detect numerical issues."""
    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "BertModel":
        # Ensure BertEmbeddingModel is used for embedding models
        vllm_extra_kwargs["trust_remote_code"] = True

    mteb_test_embed_models(hf_runner,
                           vllm_runner,
                           model_info,
                           vllm_extra_kwargs,
                           atol=MTEB_EMBED_TOL)
Comment on lines +63 to +76: Please keep only `test_embed_models_mteb`; most of the tests in this folder are like this.
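If the file is trimmed as the reviewer asks, it could reduce to something like the sketch below. This is only an illustration: it reuses the model list and the `mteb_test_embed_models` call from the diff above and assumes the helper functions and the extra tests further down are simply dropped.

```python
# Hypothetical trimmed version of this test file, keeping only the MTEB test
# as requested in the review. All names are taken from the diff above.
from typing import Any

import pytest

from ...utils import EmbedModelInfo
from .mteb_utils import MTEB_EMBED_TOL, mteb_test_embed_models

ST_PROJECTOR_MODELS = [
    EmbedModelInfo("TencentBAC/Conan-embedding-v1",
                   architecture="BertModel",
                   enable_test=True),
]


@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:
    # Same override as in the PR: make sure BertEmbeddingModel is used.
    vllm_extra_kwargs: dict[str, Any] = {"trust_remote_code": True}
    mteb_test_embed_models(hf_runner, vllm_runner, model_info,
                           vllm_extra_kwargs, atol=MTEB_EMBED_TOL)
```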
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_st_projector_loading(vllm_runner, model_info: EmbedModelInfo) -> None:
    """Ensure projector models load and output the expected dim."""
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    test_texts = ["This is a test sentence."]
    embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)

    actual_dim = embeddings_data.shape[-1]
    expected_dim = 1792
    assert actual_dim == expected_dim, (
        f"Expected {expected_dim}, got {actual_dim}")


@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_compare_with_hf_dimensions(hf_runner, vllm_runner,
                                    model_info: EmbedModelInfo) -> None:
    """Compare embedding dimensions between vLLM and HuggingFace."""
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    test_texts = ["This is a test sentence for dimension comparison."]

    vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
    hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)

    vllm_dim = vllm_data.shape[-1]
    hf_dim = hf_data.shape[-1]

    assert vllm_dim == hf_dim, ("Embedding dim mismatch: "
                                f"vLLM {vllm_dim} vs HF {hf_dim}")
    print(f"✓ Embedding dimensions match: {vllm_dim}")


@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_embedding_numerical_similarity(hf_runner, vllm_runner,
                                        model_info: EmbedModelInfo) -> None:
    """Numerical similarity between vLLM and HF embeddings."""
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    test_texts = [
        "This is a test sentence for numerical comparison.",
        "Another sentence to verify embedding quality.",
        "机器学习是人工智能的一个重要分支。",  # Chinese test
    ]

    vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
    hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)

    assert vllm_data.shape == hf_data.shape, (
        "Shape mismatch: "
        f"vLLM {vllm_data.shape} vs HF {hf_data.shape}")

    print(f"Embedding shape: {vllm_data.shape}")
    print(f"Embedding dimension: {vllm_data.shape[-1]}")

    similarities = []
    for i, text in enumerate(test_texts):
        vllm_emb = vllm_data[i]
        hf_emb = hf_data[i]

        similarity = 1 - cosine(vllm_emb, hf_emb)
        similarities.append(similarity)

        preview = text[:50] + ("..." if len(text) > 50 else "")
        print(f"Text {i + 1}: '{preview}'")
        print(f"  Cosine similarity: {similarity:.6f}")

        min_similarity = 0.95
        assert similarity > min_similarity, (
            f"Text {i + 1} similarity too low: "
            f"{similarity:.6f} < {min_similarity}\n"
            f"vLLM norm: {np.linalg.norm(vllm_emb):.6f}, "
            f"HF norm: {np.linalg.norm(hf_emb):.6f}")

    avg_similarity = np.mean(similarities)
    print(f"\nAverage cosine similarity: {avg_similarity:.6f}")

    assert avg_similarity > 0.98, (
        f"Average similarity too low: {avg_similarity:.6f} < 0.98")
    print("✓ All numerical similarity tests passed!")


@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_embedding_quality_checks(vllm_runner,
                                  model_info: EmbedModelInfo) -> None:
    """Basic quality checks: non-zero, non-constant, distinct."""
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    test_texts = [
        "First test sentence.",
        "Second different sentence.",
        "Completely different content here.",
    ]

    embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)

    print(f"Embeddings shape: {embeddings_data.shape}")

    # Non-zero and non-constant
    for i, emb in enumerate(embeddings_data):
        norm = np.linalg.norm(emb)
        print(f"Embedding {i + 1} L2 norm: {norm:.6f}")
        assert norm > 1e-6, (
            f"Embedding {i + 1} too close to zero: norm={norm}")

        std = np.std(emb)
        print(f"Embedding {i + 1} std: {std:.6f}")
        assert std > 1e-6, (
            f"Embedding {i + 1} too close to constant: std={std}")

    # Different texts should differ
    for i in range(len(embeddings_data)):
        for j in range(i + 1, len(embeddings_data)):
            sim = 1 - cosine(embeddings_data[i], embeddings_data[j])
            print(f"Similarity between text {i + 1} and {j + 1}: {sim:.6f}")
            assert sim < 0.99, ("Embeddings too similar: "
                                f"{i + 1} vs {j + 1} -> {sim:.6f}")

    print("✓ All embedding quality checks passed!")
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from enum import IntEnum
 from itertools import groupby
-from typing import Callable, Optional, TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union, cast

 import torch
 import torch.nn as nn

@@ -77,13 +77,17 @@ def for_encode(pooler_config: PoolerConfig):
         return SimplePooler.from_config(resolved_config)

     @staticmethod
-    def for_embed(pooler_config: PoolerConfig):
+    def for_embed(
+        pooler_config: PoolerConfig,
+        *,
+        projector: Optional[nn.Module] = None,
+    ):
         resolved_config = ResolvedPoolingConfig.from_config(
             task="embed",
             pooler_config=pooler_config,
         )

-        return SimplePooler.from_config(resolved_config)
+        return SimplePooler.from_config(resolved_config, projector=projector)

     @staticmethod
     def for_classify(

@@ -454,12 +458,77 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],

 class EmbeddingPoolerHead(PoolerHead):

-    def __init__(self) -> None:
+    def __init__(self, projector: Optional[nn.Module] = None) -> None:
         super().__init__(activation=PoolerNormalize())
+        self.projector = projector
+        self._projector_dim_checked = False
+
+    def _sync_projector_to_ref(self, ref_tensor: torch.Tensor) -> None:
+        """Ensure projector is on correct device with float32 dtype."""
+        if self.projector is None:
+            return
+
+        projector = cast(nn.Module, self.projector)
+        try:
+            proj_device = next(projector.parameters()).device
+            if proj_device != ref_tensor.device:
+                projector.to(device=ref_tensor.device, dtype=torch.float32)
+            # Ensure all parameters are float32
+            for param in projector.parameters():
+                param.data = param.data.to(torch.float32)
+        except StopIteration:
+            # Empty projector, skip device check
+            pass
Comment on lines +465 to +482: I don't think this is needed.
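One way to act on this comment would be to drop the runtime syncing entirely and fix the projector's dtype once at construction time. A minimal sketch of that idea (my own simplification, not code from the PR), assuming device placement is handled by the surrounding model:

```python
# Hypothetical simplification: no _sync_projector_to_ref; the Dense projector
# is cast to float32 once when the head is built.
class EmbeddingPoolerHead(PoolerHead):

    def __init__(self, projector: Optional[nn.Module] = None) -> None:
        super().__init__(activation=PoolerNormalize())
        # Keep the projector in float32 for numerical stability; device
        # placement follows the module it is registered under.
        self.projector = projector.float() if projector is not None else None
```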
+    def _validate_projector_dimensions(self, ref_tensor: torch.Tensor) -> None:
+        """Validate projector input dimensions match pooled output."""
+        if self.projector is None:
+            return
+
+        projector = cast(nn.Module, self.projector)
+        first_linear = None
+        for module in projector.modules():
+            if isinstance(module, nn.Linear):
+                first_linear = module
+                break
+
+        if first_linear is not None:
+            expected_dim = first_linear.in_features
+            actual_dim = ref_tensor.shape[-1]
+            if expected_dim != actual_dim:
+                raise ValueError(
+                    f"Dimension mismatch: Dense projector expects "
+                    f"input dim {expected_dim}, but pooled output "
+                    f"has dim {actual_dim}")
Comment on lines +483 to +502: I don't think there's a need for a dynamic projector-dimensions check.

     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):

+        # Apply ST projector
+        if self.projector is not None:
+            if isinstance(pooled_data, list) and len(pooled_data) == 0:
+                pass  # Skip projection for empty inputs
Comment on lines +509 to +510: This situation should not happen.
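Taken together, the three comments above suggest the projector application could be much smaller. A possible condensed form (my own sketch, assuming the projector's device, dtype, and input dimension are already correct by the time the model is loaded):

```python
# Hypothetical condensed helper for EmbeddingPoolerHead: apply the Dense
# projector without runtime device sync, dimension checks, or an empty-input
# branch.
def _apply_projector(self, pooled_data):
    if self.projector is None:
        return pooled_data

    def _proj(x: torch.Tensor) -> torch.Tensor:
        # Run the Dense layer in float32, then cast back to the pooled dtype.
        return self.projector(x.float()).to(x.dtype)

    if isinstance(pooled_data, torch.Tensor):
        return _proj(pooled_data)
    return [_proj(t) for t in pooled_data]
```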
+            else:
+                projector = cast(nn.Module, self.projector)
+                ref = pooled_data[0] if isinstance(pooled_data,
+                                                   list) else pooled_data
+
+                self._sync_projector_to_ref(ref)
+
+                if not self._projector_dim_checked:
+                    self._validate_projector_dimensions(ref)
+                    self._projector_dim_checked = True
+
+                def _proj(x: torch.Tensor) -> torch.Tensor:
+                    orig_dtype = x.dtype
+                    y = projector(x.to(torch.float32))
+                    return y.to(orig_dtype)
+
+                if isinstance(pooled_data, torch.Tensor):
+                    pooled_data = _proj(pooled_data)
+                else:
+                    pooled_data = [_proj(t) for t in pooled_data]
+
         pooling_params = get_pooling_params(pooling_metadata)

         # for matryoshka representation

@@ -530,12 +599,13 @@ class SimplePooler(Pooler):
     def from_config(
         cls,
         pooler_config: ResolvedPoolingConfig,
+        projector: Optional[nn.Module] = None,
     ) -> "SimplePooler":
         pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
         if pooler_config.task == "embed":
-            head = EmbeddingPoolerHead()
+            head = EmbeddingPoolerHead(projector=projector)
         elif pooler_config.task == "encode":
-            head = RewardPoolerHead()
+            head = EmbeddingPoolerHead()  # no projector
Comment on lines 607 to +608: This change replaces …
Suggested change:
Reply: Issue addressed. Ready for review. Thanks!! @DarkLight1337
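The suggested change itself is not captured above, but reading the diff, the `encode` branch now builds an `EmbeddingPoolerHead` where it previously built a `RewardPoolerHead`. Assuming that is the issue being raised, the resolved `from_config` would presumably look like this sketch (my reading, not the reviewer's exact suggestion):

```python
# Hypothetical resolved version: the projector is used only for the "embed"
# task, and the "encode" task keeps its original RewardPoolerHead.
@classmethod
def from_config(
    cls,
    pooler_config: ResolvedPoolingConfig,
    projector: Optional[nn.Module] = None,
) -> "SimplePooler":
    pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
    if pooler_config.task == "embed":
        head = EmbeddingPoolerHead(projector=projector)
    elif pooler_config.task == "encode":
        head = RewardPoolerHead()
    else:
        raise NotImplementedError(f"Unknown task: {pooler_config.task}")
    return cls(pooling, head)
```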
         else:
             raise NotImplementedError(f"Unknown task: {pooler_config.task}")
         return cls(pooling, head)
Comment: `hf_runner` and `vllm_runner` already include `trust_remote_code=True`.
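If the fixtures really do set `trust_remote_code=True` by default, the explicit flags in the new test file are redundant, and the vLLM helper could shrink to something like the sketch below (an assumption based on the comment above; it reuses the imports already present in the test file and omits the `GteNewModel` override handling for brevity):

```python
# Hypothetical simplification of _get_vllm_embeddings: the explicit
# trust_remote_code=True argument is dropped because the vllm_runner fixture
# is said to enable it already.
def _get_vllm_embeddings(vllm_runner, model_info: EmbedModelInfo,
                         test_texts: list[str]):
    with vllm_runner(
            model_info.name,
            runner="pooling",
            max_model_len=None,
    ) as vllm_model:
        embeddings = vllm_model.encode(test_texts)

    # Same extraction logic as the helper in the diff above.
    data = []
    for emb in embeddings:
        if hasattr(emb, "outputs"):
            data.append(emb.outputs.data.cpu().numpy())
        else:
            data.append(emb.cpu().numpy() if hasattr(emb, "cpu") else emb)
    return np.array(data)
```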