Merged
39 commits
042dc8f  Enable CPU device for python backend (yuanwu2017, Apr 22, 2024)
5bd2008  Add the flash attention for cpu optimization. (yuanwu2017, Apr 23, 2024)
e0605d2  Merge branch 'huggingface:main' into ipex (yuanwu2017, Jun 13, 2024)
51a4af4  Refine the patch (yuanwu2017, Jun 12, 2024)
d512003  Use the latest ipex (yuanwu2017, Jun 23, 2024)
c49cb7b  Merge branch 'huggingface:main' into ipex (yuanwu2017, Jun 23, 2024)
8e499ea  Merge pull request #1 from kaixuanliu/ipex (yuanwu2017, Jul 17, 2024)
70e154e  add python backend support for xlm-roberta type model (kaixuanliu, Jul 17, 2024)
35d3fa0  refine code (kaixuanliu, Jul 18, 2024)
af2e942  nice code (kaixuanliu, Jul 18, 2024)
e7a7388  Merge pull request #2 from kaixuanliu/ipex (yuanwu2017, Jul 18, 2024)
4d94445  Merge branch 'main' into ipex (yuanwu2017, Jul 30, 2024)
f61b8bd  Fix build error (yuanwu2017, Jul 30, 2024)
1c2e6ab  add XPU and HPU support (kaixuanliu, Aug 15, 2024)
e161e83  change build-arg to `cpu` instead of `ipex` (kaixuanliu, Aug 15, 2024)
fc979a9  nice code (kaixuanliu, Aug 15, 2024)
09a0d3f  Merge pull request #3 from kaixuanliu/ipex (yuanwu2017, Aug 15, 2024)
8d8148c  add import ipex (kaixuanliu, Aug 22, 2024)
081ab41  Merge pull request #4 from kaixuanliu/ipex (yuanwu2017, Aug 22, 2024)
cbc3ee2  add hpu flashBert support (kaixuanliu, Aug 29, 2024)
f7d1e1b  Merge pull request #5 from kaixuanliu/ipex (yuanwu2017, Aug 29, 2024)
fd5ad34  nice code (kaixuanliu, Aug 29, 2024)
8cb7fa0  update version (kaixuanliu, Aug 29, 2024)
8b2d1bf  Merge pull request #6 from kaixuanliu/ipex (yuanwu2017, Aug 29, 2024)
0f09872  Cargo fmt and ruff format (pi314ever, Aug 30, 2024)
7684eeb  Merge pull request #8 from pi314ever/dh/ipex-formatting (yuanwu2017, Aug 30, 2024)
5dcd07b  Unused imports and better imports (pi314ever, Sep 5, 2024)
984532d  Clippy fixes (pi314ever, Sep 5, 2024)
66f7775  End of File fixes (pi314ever, Sep 5, 2024)
542604c  Nicer importlib imports (pi314ever, Sep 10, 2024)
267c559  Compact code (pi314ever, Sep 10, 2024)
c70252c  Merge pull request #9 from pi314ever/dh/ipex-minor-cleanup (yuanwu2017, Sep 19, 2024)
978c1fe  fix bug (kaixuanliu, Sep 19, 2024)
c6aba0d  Merge pull request #10 from kaixuanliu/ipex (yuanwu2017, Sep 20, 2024)
e75ad3f  upgrade xpu-ipex to 2.3.110 (kaixuanliu, Sep 24, 2024)
04744eb  Merge pull request #11 from kaixuanliu/ipex (yuanwu2017, Sep 25, 2024)
9079081  Merge branch 'main' into ipex (yuanwu2017, Dec 12, 2024)
8cd73a2  fix conflict env (kaixuanliu, Dec 16, 2024)
5b19baf  Merge pull request #12 from kaixuanliu/ipex-conflict-resolve (yuanwu2017, Dec 16, 2024)
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -30,6 +30,7 @@ tracing = "0.1"
 serde = { version = "1.0", features = ["serde_derive"] }
 serde_json = "1.0"
 thiserror = "1.0"
+rand = "0.8"


 [patch.crates-io]
146 changes: 146 additions & 0 deletions Dockerfile-intel
@@ -0,0 +1,146 @@
ARG PLATFORM=cpu
FROM lukemathwalker/cargo-chef:latest-rust-1.75-bookworm AS chef
WORKDIR /usr/src
ENV SCCACHE=0.5.4
ENV RUSTC_WRAPPER=/usr/local/bin/sccache

# Download and configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

FROM chef AS planner

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

# sccache specific variables
ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
ARG SCCACHE_GHA_ENABLED

COPY --from=planner /usr/src/recipe.json recipe.json

RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

FROM builder as http-builder

RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

FROM builder as grpc-builder

COPY proto proto

RUN cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s

FROM intel/intel-optimized-pytorch:2.4.0-pip-base AS cpu
ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
git \
cmake \
ninja-build \
python3-dev &&\
rm -rf /var/lib/apt/lists/*

WORKDIR /usr/src
COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt

RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu

RUN cd backends/python/server && \
make install

FROM vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
git \
cmake \
ninja-build \
python3-dev &&\
rm -rf /var/lib/apt/lists/*

WORKDIR /usr/src
COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
COPY backends/python/server/requirements-hpu.txt backends/python/server/requirements.txt

RUN cd backends/python/server && \
make install

FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb

RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
WORKDIR /usr/src
RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir

ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include

COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
RUN cd backends/python/server && \
make install

FROM ${PLATFORM} AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]

FROM ${PLATFORM}

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
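
The runtime base of both final stages is selected by the `PLATFORM` build argument (`cpu` by default, `hpu`, or `xpu`), resolved through `FROM ${PLATFORM}`. As an illustration (the image tag `tei-xpu-grpc` is made up for this example), a gRPC image for Intel GPUs could be built with:

docker build -f Dockerfile-intel --build-arg PLATFORM=xpu --target grpc -t tei-xpu-grpc .

Omitting `--target` builds the last stage, which ships the HTTP variant of `text-embeddings-router`.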
1 change: 1 addition & 0 deletions backends/Cargo.toml
@@ -15,6 +15,7 @@ text-embeddings-backend-candle = { path = "candle", optional = true }
 text-embeddings-backend-ort = { path = "ort", optional = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
+rand = { workspace = true }

 [features]
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]
3 changes: 1 addition & 2 deletions backends/python/server/Makefile
@@ -15,8 +15,7 @@ gen-server:

 install: gen-server
 	pip install pip --upgrade
-	pip install torch==2.5.1
-	pip install -r requirements.txt
+	pip install --no-deps -r requirements.txt
 	pip install -e .

 run-dev:
1 change: 0 additions & 1 deletion backends/python/server/pyproject.toml
@@ -21,7 +21,6 @@ opentelemetry-api = "^1.25.0"
 opentelemetry-exporter-otlp = "^1.25.0"
 opentelemetry-instrumentation-grpc = "^0.46b0"
 sentence-transformers = "^3.3.1"
-torch = "^2.5.1"

 [tool.poetry.extras]

61 changes: 61 additions & 0 deletions backends/python/server/requirements-hpu.txt
@@ -0,0 +1,61 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.24.5 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.3 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.18.1 ; python_version >= "3.9" and python_version < "3.13"
pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13"
44 changes: 44 additions & 0 deletions backends/python/server/requirements-intel.txt
@@ -0,0 +1,44 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.3 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
1 change: 0 additions & 1 deletion backends/python/server/text_embeddings_server/cli.py
@@ -48,7 +48,6 @@ def serve(

     # Downgrade enum into str for easier management later on
     dtype = None if dtype is None else dtype.value
-
     server.serve(model_path, dtype, uds_path, pool)

36 changes: 23 additions & 13 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -8,6 +8,7 @@

 from text_embeddings_server.models.model import Model
 from text_embeddings_server.models.default_model import DefaultModel
+from text_embeddings_server.utils.device import get_device, use_ipex

 __all__ = ["Model"]

@@ -27,33 +28,42 @@

 def get_model(model_path: Path, dtype: Optional[str], pool: str):
     if dtype == "float32":
-        dtype = torch.float32
+        datatype = torch.float32
     elif dtype == "float16":
-        dtype = torch.float16
+        datatype = torch.float16
     elif dtype == "bfloat16":
-        dtype = torch.bfloat16
+        datatype = torch.bfloat16
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")

-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
+    device = get_device()
     logger.info(f"backend device: {device}")

     config = AutoConfig.from_pretrained(model_path)

     if config.model_type == "bert":
         config: BertConfig
         if (
             device.type == "cuda"
             and config.position_embedding_type == "absolute"
-            and dtype in [torch.float16, torch.bfloat16]
+            and datatype in [torch.float16, torch.bfloat16]
             and FLASH_ATTENTION
         ):
             if pool != "cls":
                 raise ValueError("FlashBert only supports cls pooling")
-            return FlashBert(model_path, device, dtype)
-        else:
-            return DefaultModel(model_path, device, dtype, pool)
+            return FlashBert(model_path, device, datatype)  # type: ignore
+        if use_ipex() or device.type == "hpu":
+            return FlashBert(model_path, device, datatype)  # type: ignore
+
+        return DefaultModel(model_path, device, datatype)
     else:
-        raise NotImplementedError
+        if device.type == "hpu":
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+            from optimum.habana.transformers.modeling_utils import (
+                adapt_transformers_to_gaudi,
+            )
+
+            adapt_transformers_to_gaudi()
+            model_handle = DefaultModel(model_path, device, datatype)
+            model_handle.model = wrap_in_hpu_graph(model_handle.model)
+            return model_handle
+        return DefaultModel(model_path, device, datatype)
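
`get_device` and `use_ipex` are imported from a new `text_embeddings_server/utils/device.py` that is not part of this diff view. A minimal sketch of what such helpers might look like, assuming detection relies on the standard availability checks of torch, intel_extension_for_pytorch, and habana_frameworks (names and logic here are illustrative, not the PR's actual implementation):

# Hypothetical sketch; the PR's real utils/device.py is not shown in this diff.
import importlib.util

import torch


def use_ipex() -> bool:
    # Assume IPEX is used whenever intel_extension_for_pytorch is importable.
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None


def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu becomes available once intel_extension_for_pytorch is imported.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    if importlib.util.find_spec("habana_frameworks") is not None:
        # Importing habana_frameworks.torch.hpu registers the "hpu" device.
        import habana_frameworks.torch.hpu as hpu

        if hpu.is_available():
            return torch.device("hpu")
    return torch.device("cpu")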
@@ -43,7 +43,6 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             kwargs["token_type_ids"] = batch.token_type_ids
         if self.has_position_ids:
             kwargs["position_ids"] = batch.position_ids
-
         output = self.model(**kwargs)

         pooling_features = {