Skip to content

Enable splade embeddings for Python backend #493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions backends/python/server/text_embeddings_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformers.models.bert import BertConfig

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.masked_model import MaskedLanguageModel
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.utils.device import get_device, use_ipex
Expand Down Expand Up @@ -53,6 +54,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
and FLASH_ATTENTION
):
if pool != "cls":
if config.architectures[0].endswith("ForMaskedLM"):
return MaskedLanguageModel(
model_path,
device,
datatype,
pool,
trust_remote=TRUST_REMOTE_CODE,
)
return DefaultModel(
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
)
Expand All @@ -61,6 +70,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
return ClassificationModel(
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
)
elif config.architectures[0].endswith("ForMaskedLM"):
return MaskedLanguageModel(
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
)
else:
return DefaultModel(
model_path,
Expand All @@ -84,6 +97,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
datatype,
trust_remote=TRUST_REMOTE_CODE,
)
elif config.architectures[0].endswith("ForMaskedLM"):
return MaskedLanguageModel(
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
)
else:
model_handle = DefaultModel(
model_path,
Expand All @@ -102,6 +119,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
datatype,
trust_remote=TRUST_REMOTE_CODE,
)
elif config.architectures[0].endswith("ForMaskedLM"):
return MaskedLanguageModel(
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
)
else:
return DefaultModel(
model_path,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import inspect
import torch

from pathlib import Path
from typing import Type, List
from transformers import AutoModelForMaskedLM
from opentelemetry import trace

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling

tracer = trace.get_tracer(__name__)


class MaskedLanguageModel(Model):
    """Embedding model backed by a masked-LM head (e.g. BERT-style *ForMaskedLM).

    Supports "splade" pooling over the MLM logits as well as the standard
    sentence-embedding pooling modes (cls/mean/lasttoken).
    """

    def __init__(
        self,
        model_path: Path,
        device: torch.device,
        dtype: torch.dtype,
        pool: str,
        trust_remote: bool = False,
    ):
        loaded = AutoModelForMaskedLM.from_pretrained(
            model_path, trust_remote_code=trust_remote
        )
        loaded = loaded.to(dtype).to(device)

        self.hidden_size = loaded.config.hidden_size
        self.vocab_size = loaded.config.vocab_size
        self.pooling_mode = pool
        # Splade has its own pooling path; every other mode delegates to the
        # sentence-transformers style pooling layer.
        if pool == "splade":
            self.pooling = SpladePooling()
        else:
            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)

        # Probe the forward signature once so embed() only passes the
        # arguments this architecture actually accepts.
        forward_params = inspect.signature(loaded.forward).parameters
        self.has_position_ids = forward_params.get("position_ids", None) is not None
        self.has_token_type_ids = (
            forward_params.get("token_type_ids", None) is not None
        )

        super().__init__(model=loaded, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        """This model consumes padded (non-flash) batches."""
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        """Run the model on *batch* and return one pooled embedding per sequence."""
        model_inputs = {
            "input_ids": batch.input_ids,
            "attention_mask": batch.attention_mask,
        }
        if self.has_token_type_ids:
            model_inputs["token_type_ids"] = batch.token_type_ids
        if self.has_position_ids:
            model_inputs["position_ids"] = batch.position_ids

        raw_output = self.model(**model_inputs)
        pooled = self.pooling.forward(raw_output, batch.attention_mask)

        # Flatten once, then slice the flat list back into per-sequence rows.
        dim = pooled.shape[-1]
        flat = pooled.view(-1).tolist()
        return [
            Embedding(values=flat[start : start + dim])
            for start in range(0, len(batch) * dim, dim)
        ]

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Score]:
        # Classification scoring is not applicable to masked-LM embedding
        # models; mirrors the sibling models' unimplemented predict().
        pass
40 changes: 40 additions & 0 deletions backends/python/server/text_embeddings_server/models/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod

import torch
from opentelemetry import trace
from sentence_transformers.models import Pooling
from torch import Tensor

tracer = trace.get_tracer(__name__)


class _Pooling(ABC):
@abstractmethod
def forward(self, model_output, attention_mask) -> Tensor:
pass


class DefaultPooling(_Pooling):
    """Standard pooling (cls/mean/lasttoken) delegating to the
    sentence-transformers ``Pooling`` module.

    ``pooling_mode`` must be one of the modes that module supports;
    "splade" is handled by :class:`SpladePooling` instead.
    """

    def __init__(self, hidden_size, pooling_mode) -> None:
        # Use an explicit raise rather than `assert`: asserts are stripped
        # under `python -O`, which would let the unsupported "splade" mode
        # fall through silently.
        if pooling_mode == "splade":
            raise ValueError("Splade pooling is not supported for DefaultPooling")
        self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)

    @tracer.start_as_current_span("pooling")
    def forward(self, model_output, attention_mask) -> Tensor:
        """Pool the token embeddings (first element of *model_output*) into
        one sentence embedding per sequence."""
        pooling_features = {
            "token_embeddings": model_output[0],
            "attention_mask": attention_mask,
        }
        return self.pooling.forward(pooling_features)["sentence_embedding"]


class SpladePooling(_Pooling):
    """Splade sparse pooling: log-saturated ReLU over the MLM logits,
    masked by attention, then max-pooled over the sequence axis."""

    @tracer.start_as_current_span("pooling")
    def forward(self, model_output, attention_mask) -> Tensor:
        # log(1 + relu(logits)) gives the splade activation per token.
        activated = (torch.relu(model_output[0]) + 1).log()
        # Zero out padding positions before taking the per-dimension max.
        masked = activated * attention_mask.unsqueeze(-1)
        return masked.max(dim=1).values
4 changes: 1 addition & 3 deletions backends/python/src/management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ impl BackendProcess {
Pool::Cls => "cls",
Pool::Mean => "mean",
Pool::LastToken => "lasttoken",
Pool::Splade => {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
Pool::Splade => "splade",
};

// Process args
Expand Down