add reranker model support for python backend #386

Merged 5 commits on Jan 21, 2025
2 changes: 1 addition & 1 deletion Dockerfile-intel
@@ -77,7 +77,7 @@ RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url
RUN cd backends/python/server && \
make install

FROM vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
FROM vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest AS hpu
ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

21 changes: 21 additions & 0 deletions backends/grpc-client/src/client.rs
@@ -64,4 +64,25 @@ impl Client {
let response = self.stub.embed(request).await?.into_inner();
Ok(response.embeddings)
}

#[instrument(skip_all)]
pub async fn predict(
&mut self,
input_ids: Vec<u32>,
token_type_ids: Vec<u32>,
position_ids: Vec<u32>,
cu_seq_lengths: Vec<u32>,
max_length: u32,
) -> Result<Vec<Score>> {
let request = tonic::Request::new(EmbedRequest {
input_ids,
token_type_ids,
position_ids,
max_length,
cu_seq_lengths,
})
.inject_context();
let response = self.stub.predict(request).await?.into_inner();
Ok(response.scores)
}
}
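
The new `predict` client method reuses the flat, packed request layout that `embed` already sends: token ids from every sequence in the batch are concatenated into one vector, and `cu_seq_lengths` marks the sequence boundaries. A minimal sketch of that packing, using placeholder token ids and assuming position ids restart at 0 for each sequence:

```python
# Two sequences packed into the flat fields that Client::predict sends.
# Token ids are placeholders; in the real backend they come from the tokenizer.
seq_a = [101, 2054, 2003, 102]        # 4 tokens
seq_b = [101, 7592, 2088, 2339, 102]  # 5 tokens

input_ids = seq_a + seq_b                                   # 9 tokens total
token_type_ids = [0] * len(input_ids)                       # single-segment inputs
position_ids = list(range(len(seq_a))) + list(range(len(seq_b)))
cu_seq_lengths = [0, len(seq_a), len(seq_a) + len(seq_b)]   # [0, 4, 9]
max_length = max(len(seq_a), len(seq_b))                    # longest sequence in the batch
```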
10 changes: 10 additions & 0 deletions backends/proto/embed.proto
@@ -7,6 +7,8 @@ service EmbeddingService {
rpc Embed (EmbedRequest) returns (EmbedResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
/// Predict
rpc Predict (EmbedRequest) returns (PredictResponse);
}

message HealthRequest {}
@@ -28,3 +30,11 @@ message Embedding {
message EmbedResponse {
repeated Embedding embeddings = 1;
}

message Score {
repeated float values = 1;
}

message PredictResponse {
repeated Score scores = 1;
}
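
`Predict` deliberately shares `EmbedRequest` with `Embed`, so a client only needs the new `Score`/`PredictResponse` messages to read back one row of logits per packed sequence. A rough sketch with the Python stubs generated from this proto; the stub module names and the Unix-socket address are assumptions:

```python
import grpc

# Modules generated from embed.proto; exact names depend on how the stubs are built.
from text_embeddings_server.pb import embed_pb2, embed_pb2_grpc

channel = grpc.insecure_channel("unix:///tmp/text-embeddings-server-0")  # assumed socket path
stub = embed_pb2_grpc.EmbeddingServiceStub(channel)

request = embed_pb2.EmbedRequest(
    input_ids=[101, 2054, 2003, 102],
    token_type_ids=[0, 0, 0, 0],
    position_ids=[0, 1, 2, 3],
    cu_seq_lengths=[0, 4],
    max_length=4,
)
response = stub.Predict(request)
for score in response.scores:
    print(list(score.values))  # one list of classification logits per sequence
```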
9 changes: 5 additions & 4 deletions backends/python/server/requirements-hpu.txt
@@ -1,3 +1,4 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -31,8 +32,8 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.23.3 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -46,8 +47,8 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
27 changes: 18 additions & 9 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -8,6 +8,7 @@

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.utils.device import get_device, use_ipex

__all__ = ["Model"]
@@ -43,18 +44,19 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
if config.model_type == "bert":
config: BertConfig
if (
device.type == "cuda"
use_ipex()
or device.type in ["cuda", "hpu"]
and config.position_embedding_type == "absolute"
and datatype in [torch.float16, torch.bfloat16]
and FLASH_ATTENTION
):
if pool != "cls":
raise ValueError("FlashBert only supports cls pooling")
return FlashBert(model_path, device, datatype) # type: ignore
if use_ipex() or device.type == "hpu":
return FlashBert(model_path, device, datatype) # type: ignore

return DefaultModel(model_path, device, datatype)
return DefaultModel(model_path, device, datatype, pool)
return FlashBert(model_path, device, datatype)
if config.architectures[0].endswith("Classification"):
return ClassificationModel(model_path, device, datatype)
else:
return DefaultModel(model_path, device, datatype, pool)
else:
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -63,7 +65,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
)

adapt_transformers_to_gaudi()
model_handle = DefaultModel(model_path, device, datatype)
if config.architectures[0].endswith("Classification"):
model_handle = ClassificationModel(model_path, device, datatype)
else:
model_handle = DefaultModel(model_path, device, datatype, pool)
model_handle.model = wrap_in_hpu_graph(model_handle.model)
return model_handle
return DefaultModel(model_path, device, datatype)
elif use_ipex():
if config.architectures[0].endswith("Classification"):
return ClassificationModel(model_path, device, datatype)
else:
return DefaultModel(model_path, device, datatype, pool)
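
`get_model` now dispatches on the checkpoint's declared architecture: whenever `config.architectures[0]` ends with `"Classification"` it builds a `ClassificationModel`, otherwise it falls back to `DefaultModel` (which now also receives the pooling mode). A quick way to check which path a given checkpoint would take; the model id below is only an example:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("BAAI/bge-reranker-base")  # example reranker checkpoint
print(config.architectures)
# e.g. ['XLMRobertaForSequenceClassification'] ends with "Classification",
# so get_model would construct ClassificationModel rather than DefaultModel.
```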
backends/python/server/text_embeddings_server/models/classification_model.py (new file)
@@ -0,0 +1,52 @@
import inspect
import torch

from pathlib import Path
from typing import Type, List
from transformers import AutoModelForSequenceClassification
from opentelemetry import trace

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)


class ClassificationModel(Model):
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model = model.to(dtype).to(device)

self.hidden_size = model.config.hidden_size
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.has_token_type_ids = (
inspect.signature(model.forward).parameters.get("token_type_ids", None)
is not None
)

super(ClassificationModel, self).__init__(
model=model, dtype=dtype, device=device
)

@property
def batch_type(self) -> Type[PaddedBatch]:
return PaddedBatch

@tracer.start_as_current_span("embed")
def embed(self, batch: PaddedBatch) -> List[Embedding]:
pass

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
if self.has_token_type_ids:
kwargs["token_type_ids"] = batch.token_type_ids
if self.has_position_ids:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs, return_dict=True)
all_scores = output.logits.tolist()
return [Score(values=scores) for scores in all_scores]
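
`ClassificationModel.predict` runs the padded batch through an `AutoModelForSequenceClassification` head and returns the raw logits as one `Score` per input; any post-processing (e.g. a sigmoid for single-label rerankers) stays on the caller's side. A standalone sketch of the same computation outside the server, assuming an example reranker checkpoint:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "BAAI/bge-reranker-base"  # example; any sequence-classification checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()

pairs = [
    ["what is panda?", "The giant panda is a bear species endemic to China."],
    ["what is panda?", "Paris is the capital of France."],
]
batch = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    logits = model(**batch, return_dict=True).logits  # shape: (batch, num_labels)

# Mirrors the backend: one list of logits per input pair.
all_scores = logits.tolist()
print(all_scores)
```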
backends/python/server/text_embeddings_server/models/default_model.py
Expand Up @@ -8,7 +8,7 @@
from sentence_transformers.models import Pooling

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)

@@ -59,3 +59,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
)
for i in range(len(batch))
]

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
pass
backends/python/server/text_embeddings_server/models/types.py
Expand Up @@ -5,7 +5,7 @@
from opentelemetry import trace

from text_embeddings_server.pb import embed_pb2
from text_embeddings_server.pb.embed_pb2 import Embedding
from text_embeddings_server.pb.embed_pb2 import Embedding, Score

tracer = trace.get_tracer(__name__)

Expand Down
7 changes: 7 additions & 0 deletions backends/python/server/text_embeddings_server/server.py
@@ -31,6 +31,13 @@ async def Embed(self, request, context):

return embed_pb2.EmbedResponse(embeddings=embeddings)

async def Predict(self, request, context):
batch = self.model.batch_type.from_pb(request, self.model.device)

scores = self.model.predict(batch)

return embed_pb2.PredictResponse(scores=scores)


def serve(
model_path: Path,
39 changes: 29 additions & 10 deletions backends/python/src/lib.rs
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
};
use tokio::runtime::Runtime;

@@ -25,11 +25,7 @@ impl PythonBackend {
otlp_service_name: String,
) -> Result<Self, BackendError> {
let pool = match model_type {
ModelType::Classifier => {
return Err(BackendError::Start(
"`classifier` model type is not supported".to_string(),
))
}
ModelType::Classifier => Pool::Cls,
ModelType::Embedding(pool) => pool,
};

@@ -105,9 +101,32 @@
Ok(embeddings)
}

fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
Err(BackendError::Inference(
"`predict` is not implemented".to_string(),
))
fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
if !batch.raw_indices.is_empty() {
return Err(BackendError::Inference(
"raw embeddings are not supported for the Python backend.".to_string(),
));
}
let batch_size = batch.len();
let results = self
.tokio_runtime
.block_on(self.backend_client.clone().predict(
batch.input_ids,
batch.token_type_ids,
batch.position_ids,
batch.cumulative_seq_lengths,
batch.max_length,
))
.map_err(|err| BackendError::Inference(err.to_string()))?;
let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

for (i, r) in raw_results.into_iter().enumerate() {
predictions.insert(i, r);
}

Ok(predictions)
}
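
With `Client::predict` wired up and `PythonBackend::predict` returning real scores (classifier checkpoints are now loaded with CLS pooling instead of being rejected at startup), rerankers become usable end to end through the router. A hedged smoke test against the `/rerank` endpoint; the launch command, port, and model id are assumptions:

```python
import requests

# Assumes the server was started with the Python backend and a reranker checkpoint, e.g.
#   text-embeddings-router --model-id BAAI/bge-reranker-base --port 8080
payload = {
    "query": "What is Deep Learning?",
    "texts": [
        "Deep learning is a subset of machine learning based on neural networks.",
        "Paris is the capital of France.",
    ],
}
resp = requests.post("http://localhost:8080/rerank", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())  # e.g. [{"index": 0, "score": ...}, {"index": 1, "score": ...}]
```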
}