Conversation

@FFFfff1FFFfff (Contributor) commented Aug 16, 2025

Purpose:

This PR adds automatic support for Sentence-Transformers Dense projection layers in vLLM, enabling proper handling of models that require dimension transformation (e.g., 1024→1792) during embedding generation.
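
For context, an ST Dense module is essentially a learned linear projection (optionally with a bias and an activation) applied after pooling. A minimal sketch of the transformation, using the 1024→1792 example above (DenseProjector is a hypothetical name; a real layer would be configured from the model's Dense config, e.g. 2_Dense/config.json):

    import torch
    import torch.nn as nn

    # Hypothetical sketch of an ST-style Dense projection (not vLLM's code).
    class DenseProjector(nn.Module):
        def __init__(self, in_features: int = 1024, out_features: int = 1792):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features, dtype=torch.float32)
            self.activation = nn.Identity()  # ST configs may specify e.g. Tanh

        def forward(self, pooled: torch.Tensor) -> torch.Tensor:
            return self.activation(self.linear(pooled.to(torch.float32)))

    projector = DenseProjector()
    print(projector(torch.randn(2, 1024)).shape)  # torch.Size([2, 1792])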

Resolves the following issues:

  • Missing Dense projection functionality for ST models in vLLM
  • Incorrect output dimensions (1024 instead of expected 1792 for models like TencentBAC/Conan-embedding-v1)
  • Numerical inconsistency with the HuggingFace Sentence-Transformers implementation

Key Modifications

  • pooler.py: Enhanced EmbeddingPoolerHead with projector support, device sync, and dimension validation
  • adapters.py: Added _load_st_projector() to detect and load Dense layers from ST models
  • bert.py: Integrated ST projector support in BERT embedding models
  • config.py: Added get_hf_file_bytes() utility for loading model files

New Version Improvements

  • Simplified code: removed the complex token-encoding logic and improved error handling with specific exceptions
  • Enhanced testing: Added test_embed_models_mteb for MTEB compatibility validation
  • Numerical stability: Explicit float32 handling for projection operations

Test
python -m pytest tests/models/language/pooling/test_st_projector.py -v

Test Result

tests/models/language/pooling/test_st_projector.py::test_embed_models_mteb PASSED
tests/models/language/pooling/test_st_projector.py::test_st_projector_loading PASSED
tests/models/language/pooling/test_st_projector.py::test_compare_with_hf_dimensions PASSED
tests/models/language/pooling/test_st_projector.py::test_embedding_numerical_similarity PASSED
tests/models/language/pooling/test_st_projector.py::test_embedding_quality_checks PASSED


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for Sentence-Transformers Dense projection layers, a crucial enhancement for handling a broader range of embedding models. The implementation is well-structured, incorporating new helper functions for loading weights and validating configurations, and includes a comprehensive test suite. I've identified one critical issue concerning the handling of empty inputs which could lead to a server crash. Please see the detailed comment below.

@mergify bot added the ci/build label Aug 16, 2025

@DarkLight1337 (Member):
cc @noooop

@DarkLight1337 (Member):
/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for Sentence-Transformers Dense projection layers, which is a valuable addition for handling a wider range of embedding models. The implementation appears robust, and the new tests are comprehensive, covering loading, dimensionality, numerical similarity, and quality checks. However, I've identified a critical issue in vllm/model_executor/layers/pooler.py where a change inadvertently alters the behavior of reward models by replacing RewardPoolerHead with EmbeddingPoolerHead for the 'encode' task. This could lead to incorrect outputs for reward modeling tasks and needs to be addressed.

Comment on lines 607 to +608:

        elif pooler_config.task == "encode":
    -       head = RewardPoolerHead()
    +       head = EmbeddingPoolerHead()  # no projector
@gemini-code-assist (bot), critical:

This change replaces RewardPoolerHead with EmbeddingPoolerHead for the encode task. This appears to be a regression, as the encode task is used by reward models, which typically output unnormalized scores (often processed with sigmoid or softmax). RewardPoolerHead correctly handles this. In contrast, EmbeddingPoolerHead is designed for embedding models and applies normalization, which is not the desired behavior for reward models. This change will likely break reward model functionality.

Suggested change:

        elif pooler_config.task == "encode":
    -       head = EmbeddingPoolerHead()  # no projector
    +       head = RewardPoolerHead()
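
To make the behavioral difference concrete, here is a toy illustration (not the actual pooler code) of why L2 normalization is wrong for reward scores:

    import torch
    import torch.nn.functional as F

    pooled = torch.tensor([[2.0, -1.0, 0.5]])  # hypothetical raw reward scores

    # Embedding-style head: L2 normalization rescales the scores.
    print(F.normalize(pooled, p=2, dim=-1))  # tensor([[ 0.8729, -0.4364,  0.2182]])

    # Reward-style head: scores are kept raw or passed through e.g. sigmoid.
    print(torch.sigmoid(pooled))  # tensor([[0.8808, 0.2689, 0.6225]])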

@FFFfff1FFFfff (Contributor, Author):
Issue addressed. Ready for review. Thanks!! @DarkLight1337

Comment on lines +63 to +76:

    @pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
    def test_embed_models_mteb(hf_runner, vllm_runner,
                               model_info: EmbedModelInfo) -> None:
        """MTEB test for ST projector models to detect numerical issues."""
        vllm_extra_kwargs: dict[str, Any] = {}
        if model_info.architecture == "BertModel":
            # Ensure BertEmbeddingModel is used for embedding models
            vllm_extra_kwargs["trust_remote_code"] = True

        mteb_test_embed_models(hf_runner,
                               vllm_runner,
                               model_info,
                               vllm_extra_kwargs,
                               atol=MTEB_EMBED_TOL)
Contributor:
Please keep only test_embed_models_mteb; most of the tests in this folder follow that pattern.

    vllm_extra_kwargs: dict[str, Any] = {}
    if model_info.architecture == "BertModel":
        # Ensure BertEmbeddingModel is used for embedding models
        vllm_extra_kwargs["trust_remote_code"] = True
Contributor:
hf_runner and vllm_runner already include trust_remote_code=True, so this block can be dropped.
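
If the fixtures already cover it, the test might reduce to something like the following sketch (assuming vllm_extra_kwargs is optional in mteb_test_embed_models):

    @pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
    def test_embed_models_mteb(hf_runner, vllm_runner,
                               model_info: EmbedModelInfo) -> None:
        """MTEB test for ST projector models to detect numerical issues."""
        mteb_test_embed_models(hf_runner, vllm_runner, model_info,
                               atol=MTEB_EMBED_TOL)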

Comment on lines +465 to +482:

    def _sync_projector_to_ref(self, ref_tensor: torch.Tensor) -> None:
        """Ensure projector is on correct device with float32 dtype."""
        if self.projector is None:
            return

        projector = cast(nn.Module, self.projector)
        try:
            proj_device = next(projector.parameters()).device
            if proj_device != ref_tensor.device:
                projector.to(device=ref_tensor.device, dtype=torch.float32)
                # Ensure all parameters are float32
                for param in projector.parameters():
                    param.data = param.data.to(torch.float32)
        except StopIteration:
            # Empty projector, skip device check
            pass

Contributor:
I don't think this is needed.

Comment on lines +483 to +502:

    def _validate_projector_dimensions(self, ref_tensor: torch.Tensor) -> None:
        """Validate projector input dimensions match pooled output."""
        if self.projector is None:
            return

        projector = cast(nn.Module, self.projector)
        first_linear = None
        for module in projector.modules():
            if isinstance(module, nn.Linear):
                first_linear = module
                break

        if first_linear is not None:
            expected_dim = first_linear.in_features
            actual_dim = ref_tensor.shape[-1]
            if expected_dim != actual_dim:
                raise ValueError(
                    f"Dimension mismatch: Dense projector expects "
                    f"input dim {expected_dim}, but pooled output "
                    f"has dim {actual_dim}")
Contributor:
I don't think there's a need for a dynamic projector-dimensions check.

Comment on lines +509 to +510:

    if isinstance(pooled_data, list) and len(pooled_data) == 0:
        pass  # Skip projection for empty inputs
Contributor:
This situation should not happen.

Comment on lines +46 to +59:

    if weight is None:
        return False

    try:
        with torch.no_grad():
            # Ensure weights are float32 for numerical stability
            linear.weight.copy_(weight.to(torch.float32))
            if linear.bias is not None and bias is not None:
                linear.bias.copy_(bias.to(torch.float32))
        return True
    except RuntimeError as e:
        logger.warning("Failed to load weights into linear layer: %s", e)
        return False
Contributor:
Please use weight_loader; for reference:

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
    score_weight = model.lm_head.weight.data[token_ids]
    param = model.score.weight
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, score_weight)
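
Applied to the Dense layer above, a hedged sketch of what that could look like (the key names "linear.weight"/"linear.bias" assume an ST Dense state dict; the helper name mirrors the PR's):

    import torch
    import torch.nn as nn

    from vllm.model_executor.model_loader.weight_utils import default_weight_loader

    def _load_weights_to_linear(sd: dict[str, torch.Tensor],
                                linear: nn.Linear) -> bool:
        """Route ST Dense weights through vLLM's weight_loader machinery."""
        weight = sd.get("linear.weight")
        if weight is None:
            return False
        loader = getattr(linear.weight, "weight_loader", default_weight_loader)
        loader(linear.weight, weight)
        bias = sd.get("linear.bias")
        if bias is not None and linear.bias is not None:
            loader = getattr(linear.bias, "weight_loader", default_weight_loader)
            loader(linear.bias, bias)
        return True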

Comment on lines +104 to +108:

    use_bias = cfg.get("bias", True)
    # Create linear layer with float32 for numerical stability
    linear = nn.Linear(in_features, out_features, bias=use_bias)
Contributor:
You should set float32 here
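
A sketch of the suggested change, using nn.Linear's dtype argument so the comment matches the actual dtype:

    use_bias = cfg.get("bias", True)
    # Create linear layer directly in float32 for numerical stability
    linear = nn.Linear(in_features, out_features, bias=use_bias,
                       dtype=torch.float32)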

Comment on lines +109 to +141:

    # Try to load weights - first safetensors, then pytorch_model.bin
    weight_loaded = False

    # Try safetensors
    try:
        b = get_hf_file_bytes(f"{folder}/model.safetensors", model_path,
                              revision)
        if b is not None:
            from safetensors.torch import load as st_load
            sd = st_load(b)
            weight_loaded = _load_weights_to_linear(sd, linear)
    except (OSError, ImportError, ValueError) as e:
        logger.debug("Failed to load safetensors from %s: %s", folder, e)

    if not weight_loaded:
        try:
            b = get_hf_file_bytes(f"{folder}/pytorch_model.bin",
                                  model_path, revision)
            if b is not None:
                import io
                sd = torch.load(io.BytesIO(b), map_location="cpu")
                weight_loaded = _load_weights_to_linear(sd, linear)
        except (OSError, torch.serialization.UnpicklingError, RuntimeError,
                ValueError) as e:
            logger.debug("Failed to load pytorch_model.bin from %s: %s",
                         folder, e)

    if not weight_loaded:
        logger.warning("Failed to load weights for Dense layer in %s",
                       folder)

Contributor:
weight_loader will do all of this automatically.

        logger.debug("Failed to read file %s: %s", file_path, e)
        return None

    return None
Contributor:
A blank line is needed at the end.

@FFFfff1FFFfff (Contributor, Author):
Thanks for the review! My server crashed a few days ago, but I’ll get it fixed soon.

mergify bot commented Aug 20, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @FFFfff1FFFfff.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Aug 20, 2025