
Commit bb2093b

Enable splade embeddings for Python backend
Signed-off-by: Daniel Huang <[email protected]>
1 parent 4ac772e commit bb2093b

4 files changed: +113 -3 lines changed


backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -8,6 +8,7 @@
 from transformers.models.bert import BertConfig

 from text_embeddings_server.models.model import Model
+from text_embeddings_server.models.masked_model import MaskedLanguageModel
 from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
 from text_embeddings_server.utils.device import get_device, use_ipex
@@ -53,6 +54,8 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
         and FLASH_ATTENTION
     ):
         if pool != "cls":
+            if config.architectures[0].endswith("ForMaskedLM"):
+                return MaskedLanguageModel(model_path, device, datatype, pool)
             return DefaultModel(
                 model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
             )
@@ -61,6 +64,8 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             return ClassificationModel(
                 model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
             )
+        elif config.architectures[0].endswith("ForMaskedLM"):
+            return MaskedLanguageModel(model_path, device, datatype, pool)
         else:
             return DefaultModel(
                 model_path,
@@ -84,6 +89,8 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
+        elif config.architectures[0].endswith("ForMaskedLM"):
+            return MaskedLanguageModel(model_path, device, datatype, pool)
         else:
             model_handle = DefaultModel(
                 model_path,
@@ -102,6 +109,8 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 datatype,
                 trust_remote=TRUST_REMOTE_CODE,
             )
+        elif config.architectures[0].endswith("ForMaskedLM"):
+            return MaskedLanguageModel(model_path, device, datatype, pool)
         else:
             return DefaultModel(
                 model_path,
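
Note on the dispatch above (not part of the diff): get_model keys the new branch off the first entry of config.architectures, so any checkpoint whose architecture name ends in "ForMaskedLM" (e.g. BertForMaskedLM) is now served by MaskedLanguageModel, with the requested pooling mode passed through. A rough standalone sketch of that check, using AutoConfig and an example SPLADE checkpoint name purely for illustration:

# Illustration only -- mirrors the architecture check added in get_model().
from transformers import AutoConfig

model_id = "naver/efficient-splade-VI-BT-large-query"  # example checkpoint, not from this commit
config = AutoConfig.from_pretrained(model_id)
architecture = config.architectures[0]  # e.g. "BertForMaskedLM"

if architecture.endswith("ForMaskedLM"):
    print("routed to MaskedLanguageModel (splade pooling available)")
else:
    print("routed to DefaultModel / ClassificationModel as before")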
backends/python/server/text_embeddings_server/models/masked_model.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import inspect
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from transformers import AutoModelForMaskedLM
+from opentelemetry import trace
+
+from text_embeddings_server.models import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
+
+tracer = trace.get_tracer(__name__)
+
+
+class MaskedLanguageModel(Model):
+    def __init__(
+        self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
+    ):
+        model = AutoModelForMaskedLM.from_pretrained(model_path).to(dtype).to(device)
+        self.hidden_size = model.config.hidden_size
+        self.vocab_size = model.config.vocab_size
+        self.pooling_mode = pool
+        if pool == "splade":
+            self.pooling = SpladePooling()
+        else:
+            self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
+        self.has_token_type_ids = (
+            inspect.signature(model.forward).parameters.get("token_type_ids", None)
+            is not None
+        )
+
+        super(MaskedLanguageModel, self).__init__(model=model, dtype=dtype, device=device)
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+        if self.has_token_type_ids:
+            kwargs["token_type_ids"] = batch.token_type_ids
+        if self.has_position_ids:
+            kwargs["position_ids"] = batch.position_ids
+        output = self.model(**kwargs)
+        embedding = self.pooling.forward(output, batch.attention_mask)
+        cpu_results = embedding.view(-1).tolist()
+
+        step_size = embedding.shape[-1]
+        return [
+            Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
+            for i in range(len(batch))
+        ]
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        pass
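
A side note on embed() above (not part of the diff): the pooled output is flattened into a single Python list and re-sliced into per-sequence vectors of length embedding.shape[-1]. A toy, self-contained sketch of that slicing with made-up numbers:

import torch

# Toy stand-in for the pooled output: 2 sequences, vectors of dim 5.
embedding = torch.tensor([[0.0, 1.2, 0.0, 0.3, 0.0],
                          [0.7, 0.0, 0.0, 0.0, 2.1]])

cpu_results = embedding.view(-1).tolist()   # flatten to one Python list
step_size = embedding.shape[-1]             # values per sequence

per_sequence = [
    cpu_results[i * step_size : (i + 1) * step_size]
    for i in range(embedding.shape[0])
]
print(per_sequence)  # [[0.0, 1.2, 0.0, 0.3, 0.0], [0.7, 0.0, 0.0, 0.0, 2.1]]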
backends/python/server/text_embeddings_server/models/pooling.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+from abc import ABC, abstractmethod
+
+import torch
+from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from torch import Tensor
+
+tracer = trace.get_tracer(__name__)
+
+
+class _Pooling(ABC):
+    @abstractmethod
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pass
+
+
+class DefaultPooling(_Pooling):
+    def __init__(self, hidden_size, pooling_mode) -> None:
+        assert (
+            pooling_mode != "splade"
+        ), "Splade pooling is not supported for DefaultPooling"
+        self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)
+
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        pooling_features = {
+            "token_embeddings": model_output[0],
+            "attention_mask": attention_mask,
+        }
+        return self.pooling.forward(pooling_features)["sentence_embedding"]
+
+
+class SpladePooling(_Pooling):
+    @tracer.start_as_current_span("pooling")
+    def forward(self, model_output, attention_mask) -> Tensor:
+        # Implement Splade pooling
+        hidden_states = torch.relu(model_output[0])
+        hidden_states = (1 + hidden_states).log()
+        hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1))
+        return hidden_states.max(dim=1).values
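
For reference (not part of the diff): SpladePooling applies the usual SPLADE activation, log(1 + ReLU(logits)), zeroes out padded positions with the attention mask, and max-pools over the sequence dimension, yielding one vocab-sized sparse vector per input. A small self-contained check of that arithmetic with made-up shapes, using torch.log1p, which is equivalent to (1 + x).log():

import torch

batch, seq_len, vocab = 2, 4, 6              # hypothetical sizes
logits = torch.randn(batch, seq_len, vocab)  # stands in for model_output[0] (MLM logits)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

activated = torch.log1p(torch.relu(logits))            # log(1 + relu(x))
activated = activated * attention_mask.unsqueeze(-1)   # zero out padding positions
splade_vec = activated.max(dim=1).values               # max over the sequence dimension

print(splade_vec.shape)  # torch.Size([2, 6]) -- one vocab-sized vector per sequence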

backends/python/src/management.rs

Lines changed: 1 addition & 3 deletions
@@ -36,9 +36,7 @@ impl BackendProcess {
             Pool::Cls => "cls",
             Pool::Mean => "mean",
             Pool::LastToken => "lasttoken",
-            Pool::Splade => {
-                return Err(BackendError::Start(format!("{pool:?} is not supported")));
-            }
+            Pool::Splade => "splade",
         };

         // Process args
