diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index e5cbf72c..72ec500d 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -8,6 +8,7 @@ from transformers.models.bert import BertConfig
 
 from text_embeddings_server.models.model import Model
+from text_embeddings_server.models.masked_model import MaskedLanguageModel
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
 from text_embeddings_server.utils.device import get_device, use_ipex
@@ -53,6 +54,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             and FLASH_ATTENTION
         ):
             if pool != "cls":
+                if config.architectures[0].endswith("ForMaskedLM"):
+                    return MaskedLanguageModel(
+                        model_path,
+                        device,
+                        datatype,
+                        pool,
+                        trust_remote=TRUST_REMOTE_CODE,
+                    )
                 return DefaultModel(
                     model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
                 )
@@ -61,6 +70,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 return ClassificationModel(
                     model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
                 )
+            elif config.architectures[0].endswith("ForMaskedLM"):
+                return MaskedLanguageModel(
+                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                )
             else:
                 return DefaultModel(
                     model_path,
@@ -84,6 +97,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                     datatype,
                     trust_remote=TRUST_REMOTE_CODE,
                 )
+            elif config.architectures[0].endswith("ForMaskedLM"):
+                return MaskedLanguageModel(
+                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                )
             else:
                 model_handle = DefaultModel(
                     model_path,
@@ -102,6 +119,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                     datatype,
                     trust_remote=TRUST_REMOTE_CODE,
                 )
+            elif config.architectures[0].endswith("ForMaskedLM"):
+                return MaskedLanguageModel(
+                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                )
             else:
                 return DefaultModel(
                     model_path,
diff --git a/backends/python/server/text_embeddings_server/models/masked_model.py b/backends/python/server/text_embeddings_server/models/masked_model.py
new file mode 100644
index 00000000..ea622255
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/masked_model.py
@@ -0,0 +1,76 @@
+import inspect
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from transformers import AutoModelForMaskedLM
+from opentelemetry import trace
+
+from text_embeddings_server.models import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
+
+tracer = trace.get_tracer(__name__)
+
+
+class MaskedLanguageModel(Model):
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str,
+        trust_remote: bool = False,
+    ):
+        model = (
+            AutoModelForMaskedLM.from_pretrained(
+                model_path, trust_remote_code=trust_remote
+            )
+            .to(dtype)
+            .to(device)
+        )
+        self.hidden_size = model.config.hidden_size
+        self.vocab_size = model.config.vocab_size
+        self.pooling_mode = pool
+        if pool == "splade":
+            self.pooling = SpladePooling()
+        else:
+            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
+        self.has_token_type_ids = (
+            inspect.signature(model.forward).parameters.get("token_type_ids", None)
+            is not None
+        )
+
+        super(MaskedLanguageModel, self).__init__(
+            model=model, dtype=dtype, device=device
+        )
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+        if self.has_token_type_ids:
+            kwargs["token_type_ids"] = batch.token_type_ids
+        if self.has_position_ids:
+            kwargs["position_ids"] = batch.position_ids
+        output = self.model(**kwargs)
+        embedding = self.pooling.forward(output, batch.attention_mask)
+        cpu_results = embedding.view(-1).tolist()
+
+        step_size = embedding.shape[-1]
+        return [
+            Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
+            for i in range(len(batch))
+        ]
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        pass
diff --git a/backends/python/server/text_embeddings_server/models/pooling.py b/backends/python/server/text_embeddings_server/models/pooling.py
new file mode 100644
index 00000000..43f77b14
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/pooling.py
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+
+import torch
+from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from torch import Tensor
+
+tracer = trace.get_tracer(__name__)
+
+
+class _Pooling(ABC):
+    @abstractmethod
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pass
+
+
+class DefaultPooling(_Pooling):
+    def __init__(self, hidden_size, pooling_mode) -> None:
+        assert (
+            pooling_mode != "splade"
+        ), "Splade pooling is not supported for DefaultPooling"
+        self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)
+
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pooling_features = {
+            "token_embeddings": model_output[0],
+            "attention_mask": attention_mask,
+        }
+        return self.pooling.forward(pooling_features)["sentence_embedding"]
+
+
+class SpladePooling(_Pooling):
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        # Implement Splade pooling
+        hidden_states = torch.relu(model_output[0])
+        hidden_states = (1 + hidden_states).log()
+        hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1))
+        return hidden_states.max(dim=1).values
diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs
index 1ae004fd..81c294a9 100644
--- a/backends/python/src/management.rs
+++ b/backends/python/src/management.rs
@@ -36,9 +36,7 @@
             Pool::Cls => "cls",
             Pool::Mean => "mean",
             Pool::LastToken => "lasttoken",
-            Pool::Splade => {
-                return Err(BackendError::Start(format!("{pool:?} is not supported")));
-            }
+            Pool::Splade => "splade",
         };
 
         // Process args