From 37ba43db40cef8ec9db888a2eb53a8889cf11f2f Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Fri, 14 Mar 2025 06:35:50 -0400
Subject: [PATCH 1/3] fix bug for `MaskedLanguageModel` class

Signed-off-by: Liu, Kaixuan
---
 .../text_embeddings_server/models/__init__.py     |  2 +-
 .../text_embeddings_server/models/masked_model.py | 15 +++++++++++++--
 .../server/text_embeddings_server/models/types.py |  1 +
 3 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index 72ec500d..d5129f96 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -98,7 +98,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 trust_remote=TRUST_REMOTE_CODE,
             )
         elif config.architectures[0].endswith("ForMaskedLM"):
-            return MaskedLanguageModel(
+            model_handle = MaskedLanguageModel(
                 model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
             )
         else:
diff --git a/backends/python/server/text_embeddings_server/models/masked_model.py b/backends/python/server/text_embeddings_server/models/masked_model.py
index ea622255..9c114c4a 100644
--- a/backends/python/server/text_embeddings_server/models/masked_model.py
+++ b/backends/python/server/text_embeddings_server/models/masked_model.py
@@ -36,7 +36,16 @@ def __init__(
             self.pooling = SpladePooling()
         else:
             self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
-
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
+        else:
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
@@ -65,7 +74,9 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
         embedding = self.pooling.forward(output, batch.attention_mask)
         cpu_results = embedding.view(-1).tolist()

-        step_size = embedding.shape[-1]
+        step_size = (
+            embedding.shape[-1] if self.pooling_mode == "splade" else self.hidden_size
+        )
         return [
             Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
             for i in range(len(batch))
diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py
index bd22cdee..4f2cfa47 100644
--- a/backends/python/server/text_embeddings_server/models/types.py
+++ b/backends/python/server/text_embeddings_server/models/types.py
@@ -1,4 +1,5 @@
 import os
+import math
 import torch

 from abc import ABC, abstractmethod

From 3cd632ae7371653d3d28a3e337c7edc2620ce774 Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Fri, 14 Mar 2025 12:56:18 -0400
Subject: [PATCH 2/3] avoid using `AutoModelForMaskedLM` to load model if not
 using splade pooling

Signed-off-by: Liu, Kaixuan
---
 .../text_embeddings_server/models/__init__.py     |  6 +++---
 .../models/default_model.py                       | 10 +++-------
 .../text_embeddings_server/models/masked_model.py | 15 +++------------
 3 files changed, 9 insertions(+), 22 deletions(-)

diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index d5129f96..48cba2fb 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -97,7 +97,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
-        elif config.architectures[0].endswith("ForMaskedLM"):
+        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
             model_handle = MaskedLanguageModel(
                 model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
             )
@@ -119,9 +119,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
-        elif config.architectures[0].endswith("ForMaskedLM"):
+        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
             return MaskedLanguageModel(
-                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
         else:
             return DefaultModel(
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
index 896095d3..f5c569fd 100644
--- a/backends/python/server/text_embeddings_server/models/default_model.py
+++ b/backends/python/server/text_embeddings_server/models/default_model.py
@@ -5,7 +5,7 @@
 from typing import Type, List
 from transformers import AutoModel
 from opentelemetry import trace
-from sentence_transformers.models import Pooling
+from text_embeddings_server.models.pooling import DefaultPooling

 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
@@ -28,7 +28,7 @@ def __init__(
             .to(device)
         )
         self.hidden_size = model.config.hidden_size
-        self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
+        self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)

         position_offset = 0
         model_type = model.config.model_type
@@ -65,11 +65,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["position_ids"] = batch.position_ids

         output = self.model(**kwargs)
-        pooling_features = {
-            "token_embeddings": output[0],
-            "attention_mask": batch.attention_mask,
-        }
-        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
+        embedding = self.pooling.forward(output, batch.attention_mask)

         cpu_results = embedding.view(-1).tolist()

diff --git a/backends/python/server/text_embeddings_server/models/masked_model.py b/backends/python/server/text_embeddings_server/models/masked_model.py
index 9c114c4a..6a53a194 100644
--- a/backends/python/server/text_embeddings_server/models/masked_model.py
+++ b/backends/python/server/text_embeddings_server/models/masked_model.py
@@ -8,7 +8,7 @@

 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
-from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
+from text_embeddings_server.models.pooling import SpladePooling

 tracer = trace.get_tracer(__name__)

@@ -19,7 +19,6 @@ def __init__(
         model_path: Path,
         device: torch.device,
         dtype: torch.dtype,
-        pool: str,
         trust_remote: bool = False,
     ):
         model = (
@@ -29,13 +28,7 @@ def __init__(
             .to(dtype)
             .to(device)
         )
-        self.hidden_size = model.config.hidden_size
-        self.vocab_size = model.config.vocab_size
-        self.pooling_mode = pool
-        if pool == "splade":
-            self.pooling = SpladePooling()
-        else:
-            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
+        self.pooling = SpladePooling()
         position_offset = 0
         model_type = model.config.model_type
         if model_type in ["xlm-roberta", "camembert", "roberta"]:
@@ -74,9 +67,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
         embedding = self.pooling.forward(output, batch.attention_mask)
         cpu_results = embedding.view(-1).tolist()

-        step_size = (
-            embedding.shape[-1] if self.pooling_mode == "splade" else self.hidden_size
-        )
+        step_size = embedding.shape[-1]
         return [
             Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
             for i in range(len(batch))

From 92577ce7ceb935fa4d212d5dea6ae15aeeaee083 Mon Sep 17 00:00:00 2001
From: "Liu, Kaixuan"
Date: Fri, 14 Mar 2025 13:20:24 -0400
Subject: [PATCH 3/3] small fix

Signed-off-by: Liu, Kaixuan
---
 .../server/text_embeddings_server/models/__init__.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index 48cba2fb..15097be4 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -54,12 +54,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             and FLASH_ATTENTION
         ):
             if pool != "cls":
-                if config.architectures[0].endswith("ForMaskedLM"):
+                if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
                     return MaskedLanguageModel(
                         model_path,
                         device,
                         datatype,
-                        pool,
                         trust_remote=TRUST_REMOTE_CODE,
                     )
                 return DefaultModel(
@@ -70,9 +69,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             return ClassificationModel(
                 model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
-        elif config.architectures[0].endswith("ForMaskedLM"):
+        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
             return MaskedLanguageModel(
-                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
         else:
             return DefaultModel(
@@ -99,7 +98,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             )
         elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
             model_handle = MaskedLanguageModel(
-                model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
+                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
            )
         else:
             model_handle = DefaultModel(
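
Note for reviewers: after this series, MaskedLanguageModel is a SPLADE-only path (AutoModelForMaskedLM plus SpladePooling), every other pooling mode goes through DefaultModel, and `step_size = embedding.shape[-1]` in `embed()` is therefore always the size of the sparse SPLADE vector (the vocabulary size). The snippet below is a minimal sketch of what a SPLADE-style pooling step typically computes (log-saturated ReLU over the MLM logits, max-pooled across tokens). It is an illustration only: the class name `SpladePoolingSketch` and the assumption that the model output exposes `.logits` are mine, not the actual `SpladePooling` implementation in `text_embeddings_server.models.pooling`.

```python
import torch


class SpladePoolingSketch:
    """Hypothetical SPLADE-style pooling, for illustration only."""

    def forward(self, output, attention_mask):
        # Saturated activation over the MLM logits: log(1 + ReLU(logits)),
        # assuming output.logits has shape (batch, seq_len, vocab_size).
        activations = torch.log1p(torch.relu(output.logits))
        # Zero out padding positions before pooling over the sequence axis.
        mask = attention_mask.unsqueeze(-1).to(activations.dtype)
        # Max-pool over tokens: one vocab-sized sparse vector per input,
        # which is why embed() above slices flat results in steps of
        # embedding.shape[-1].
        return torch.amax(activations * mask, dim=1)
```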