
feat: support multiple backends at the same time #440

Merged 2 commits on Nov 25, 2024

1,560 changes: 921 additions & 639 deletions Cargo.lock

Large diffs are not rendered by default.

37 changes: 33 additions & 4 deletions Dockerfile
@@ -28,9 +28,22 @@ ARG ACTIONS_CACHE_URL
ARG ACTIONS_RUNTIME_TOKEN
ARG SCCACHE_GHA_ENABLED

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
intel-oneapi-mkl-devel=2024.0.0-49656 \
build-essential \
&& rm -rf /var/lib/apt/lists/*

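# Build a small shim that overrides MKL's CPU vendor check (mkl_serv_intel_cpu_true)
# so the library keeps its optimized kernels on non-Intel CPUs; it is preloaded via
# LD_PRELOAD in the runtime image below.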
RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
gcc -shared -fPIC -o libfakeintel.so fakeintel.c

COPY --from=planner /usr/src/recipe.json recipe.json

RUN cargo chef cook --release --features ort --no-default-features --recipe-path recipe.json && sccache -s
RUN cargo chef cook --release --features ort --features candle --features mkl-dynamic --no-default-features --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
@@ -40,7 +53,7 @@ COPY Cargo.lock ./

FROM builder AS http-builder

RUN cargo build --release --bin text-embeddings-router -F ort -F http --no-default-features && sccache -s
RUN cargo build --release --bin text-embeddings-router -F ort -F candle -F mkl-dynamic -F http --no-default-features && sccache -s

FROM builder AS grpc-builder

@@ -52,19 +65,35 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \

COPY proto proto

RUN cargo build --release --bin text-embeddings-router -F grpc -F ort --no-default-features && sccache -s
RUN cargo build --release --bin text-embeddings-router -F grpc -F ort -F candle -F mkl-dynamic --no-default-features && sccache -s

FROM debian:bookworm-slim AS base

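# Runtime defaults: preload the fake Intel CPU check and resolve the MKL shared
# objects copied below from /usr/local/lib.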
ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80
PORT=80 \
MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
RAYON_NUM_THREADS=8 \
LD_PRELOAD=/usr/local/libfakeintel.so \
LD_LIBRARY_PATH=/usr/local/lib

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
libomp-dev \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*

# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch...
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so

FROM base AS grpc

4 changes: 2 additions & 2 deletions backends/candle/src/lib.rs
@@ -25,7 +25,7 @@ use candle_nn::VarBuilder;
use nohash_hasher::BuildNoHashHasher;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::PathBuf;
use std::path::Path;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};
@@ -69,7 +69,7 @@ pub struct CandleBackend {

impl CandleBackend {
pub fn new(
model_path: PathBuf,
model_path: &Path,
dtype: String,
model_type: ModelType,
) -> Result<Self, BackendError> {
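
With several backends now probed in sequence, `CandleBackend::new` borrows the model directory instead of taking ownership of it, so the same path can be handed to more than one backend constructor. A minimal sketch of the difference, using hypothetical stand-in constructors rather than the real backends:

use std::path::{Path, PathBuf};

// Hypothetical stand-ins that mirror the before/after signatures.
fn new_owned(_model_path: PathBuf) {}
fn new_borrowed(_model_path: &Path) {}

fn main() {
    let model_path = PathBuf::from("/data/models/example");

    // Before: the PathBuf is moved, so a second attempt needs a clone.
    new_owned(model_path.clone());

    // After: every backend attempt can borrow the same path.
    new_borrowed(&model_path);
    new_borrowed(&model_path);
}
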
2 changes: 2 additions & 0 deletions backends/core/src/lib.rs
@@ -88,4 +88,6 @@ pub enum BackendError {
Inference(String),
#[error("Backend is unhealthy")]
Unhealthy,
#[error("Weights not found: {0}")]
WeightsNotFound(String),
}
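
The new `WeightsNotFound` variant gives weight-download failures their own error instead of reusing `Start`. A minimal sketch of how a hub download error can be mapped into it; the `fetch_onnx` helper name is illustrative, not part of the PR:

use hf_hub::api::tokio::{ApiError, ApiRepo};
use std::path::PathBuf;
use text_embeddings_backend_core::BackendError;

// Illustrative helper: turn a failed hub download into WeightsNotFound.
async fn fetch_onnx(api: &ApiRepo) -> Result<PathBuf, BackendError> {
    api.get("model.onnx")
        .await
        .map_err(|err: ApiError| BackendError::WeightsNotFound(err.to_string()))
}
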
2 changes: 1 addition & 1 deletion backends/ort/Cargo.toml
@@ -8,7 +8,7 @@ homepage.workspace = true
[dependencies]
anyhow = { workspace = true }
nohash-hasher = { workspace = true }
ndarray = "0.15.6"
ndarray = "0.16.1"
num_cpus = { workspace = true }
ort = { version = "2.0.0-rc.4", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] }
text-embeddings-backend-core = { path = "../core" }
9 changes: 5 additions & 4 deletions backends/ort/src/lib.rs
@@ -1,9 +1,9 @@
use ndarray::{s, Axis};
use nohash_hasher::BuildNoHashHasher;
use ort::{GraphOptimizationLevel, Session};
use ort::session::{builder::GraphOptimizationLevel, Session};
use std::collections::HashMap;
use std::ops::{Div, Mul};
use std::path::PathBuf;
use std::path::Path;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
};
@@ -16,12 +16,12 @@ pub struct OrtBackend {

impl OrtBackend {
pub fn new(
model_path: PathBuf,
model_path: &Path,
dtype: String,
model_type: ModelType,
) -> Result<Self, BackendError> {
// Check dtype
if &dtype == "float32" {
if dtype == "float32" {
} else {
return Err(BackendError::Start(format!(
"DType {dtype} is not supported"
@@ -246,6 +246,7 @@ impl Backend for OrtBackend {
if has_raw_requests {
// Reshape outputs
let s = outputs.shape().to_vec();
#[allow(deprecated)]
let outputs = outputs.into_shape((s[0] * s[1], s[2])).e()?;

// We need to remove the padding tokens only if batch_size > 1 and there are some
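
The `#[allow(deprecated)]` comes from the ndarray 0.15 -> 0.16 bump in this PR: 0.16 deprecates `into_shape` in favor of `into_shape_with_order` (row-major by default) or `to_shape`. A standalone sketch of the reshape done above, flattening a [batch, seq_len, hidden] output into [batch * seq_len, hidden], written under that assumption:

use ndarray::Array3;

fn main() {
    // Stand-in for the model output: [batch, seq_len, hidden].
    let outputs = Array3::<f32>::zeros((2, 4, 8));
    let s = outputs.shape().to_vec();

    // ndarray 0.16: `into_shape_with_order` keeps the default row-major
    // (C) order, matching what the deprecated `into_shape` did here.
    let flat = outputs
        .into_shape_with_order((s[0] * s[1], s[2]))
        .expect("element counts match");
    assert_eq!(flat.shape(), &[8, 8]);
}
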
145 changes: 89 additions & 56 deletions backends/src/lib.rs
@@ -37,8 +37,9 @@ pub struct Backend {
}

impl Backend {
pub fn new(
pub async fn new(
model_path: PathBuf,
api_repo: Option<ApiRepo>,
dtype: DType,
model_type: ModelType,
uds_path: String,
@@ -49,12 +50,14 @@ impl Backend {

let backend = init_backend(
model_path,
api_repo,
dtype,
model_type.clone(),
uds_path,
otlp_endpoint,
otlp_service_name,
)?;
)
.await?;
let padded_model = backend.is_padded();
let max_batch_size = backend.max_batch_size();

@@ -193,48 +196,102 @@ impl Backend {
}

#[allow(unused)]
fn init_backend(
async fn init_backend(
model_path: PathBuf,
api_repo: Option<ApiRepo>,
dtype: DType,
model_type: ModelType,
uds_path: String,
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
let mut backend_start_failed = false;

if cfg!(feature = "ort") {
#[cfg(feature = "ort")]
{
if let Some(api_repo) = api_repo.as_ref() {
let start = std::time::Instant::now();
download_onnx(api_repo)
.await
.map_err(|err| BackendError::WeightsNotFound(err.to_string()));
tracing::info!("Model ONNX weights downloaded in {:?}", start.elapsed());
}

let backend = OrtBackend::new(&model_path, dtype.to_string(), model_type.clone());
match backend {
Ok(b) => return Ok(Box::new(b)),
Err(err) => {
tracing::error!("Could not start ORT backend: {err}");
backend_start_failed = true;
}
}
}
}

if let Some(api_repo) = api_repo.as_ref() {
if cfg!(feature = "python") || cfg!(feature = "candle") {
let start = std::time::Instant::now();
if download_safetensors(api_repo).await.is_err() {
tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
tracing::info!("Downloading `pytorch_model.bin`");
api_repo
.get("pytorch_model.bin")
.await
.map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
}

tracing::info!("Model weights downloaded in {:?}", start.elapsed());
}
}

if cfg!(feature = "candle") {
#[cfg(feature = "candle")]
return Ok(Box::new(CandleBackend::new(
model_path,
dtype.to_string(),
model_type,
)?));
} else if cfg!(feature = "python") {
{
let backend = CandleBackend::new(&model_path, dtype.to_string(), model_type.clone());
match backend {
Ok(b) => return Ok(Box::new(b)),
Err(err) => {
tracing::error!("Could not start Candle backend: {err}");
backend_start_failed = true;
}
}
}
}

if cfg!(feature = "python") {
#[cfg(feature = "python")]
{
return Ok(Box::new(
std::thread::spawn(move || {
PythonBackend::new(
model_path.to_str().unwrap().to_string(),
dtype.to_string(),
model_type,
uds_path,
otlp_endpoint,
otlp_service_name,
)
})
.join()
.expect("Python Backend management thread failed")?,
));
let backend = std::thread::spawn(move || {
PythonBackend::new(
model_path.to_str().unwrap().to_string(),
dtype.to_string(),
model_type,
uds_path,
otlp_endpoint,
otlp_service_name,
)
})
.join()
.expect("Python Backend management thread failed");

match backend {
Ok(b) => return Ok(Box::new(b)),
Err(err) => {
tracing::error!("Could not start Python backend: {err}");
backend_start_failed = true;
}
}
}
} else if cfg!(feature = "ort") {
#[cfg(feature = "ort")]
return Ok(Box::new(OrtBackend::new(
model_path,
dtype.to_string(),
model_type,
)?));
}
Err(BackendError::NoBackend)

if backend_start_failed {
Err(BackendError::Start(
"Could not start a suitable backend".to_string(),
))
} else {
Err(BackendError::NoBackend)
}
}
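
With both features compiled in, `init_backend` now tries the backends in order (ORT, then Candle, then Python) and only reports `NoBackend` when no backend was compiled in at all; if at least one was attempted and failed, it returns a `Start` error instead. A condensed, self-contained sketch of that pattern with dummy types, not the real backend constructors:

// Sketch of the try-in-order fallback used by init_backend above.
#[derive(Debug)]
enum BackendError {
    Start(String),
    NoBackend,
}

fn try_backends(attempts: Vec<Result<&'static str, String>>) -> Result<&'static str, BackendError> {
    let mut backend_start_failed = false;
    for attempt in attempts {
        match attempt {
            Ok(backend) => return Ok(backend),
            Err(err) => {
                eprintln!("Could not start backend: {err}");
                backend_start_failed = true;
            }
        }
    }
    if backend_start_failed {
        Err(BackendError::Start(
            "Could not start a suitable backend".to_string(),
        ))
    } else {
        Err(BackendError::NoBackend)
    }
}

fn main() {
    // Simulate the ORT attempt failing and the Candle attempt succeeding.
    let result = try_backends(vec![Err("onnx weights missing".into()), Ok("candle")]);
    println!("{result:?}");
}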

#[derive(Debug)]
@@ -298,31 +355,6 @@ enum BackendCommand {
),
}

pub async fn download_weights(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
let model_files = if cfg!(feature = "python") || cfg!(feature = "candle") {
match download_safetensors(api).await {
Ok(p) => p,
Err(_) => {
tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
tracing::info!("Downloading `pytorch_model.bin`");
let p = api.get("pytorch_model.bin").await?;
vec![p]
}
}
} else if cfg!(feature = "ort") {
match download_onnx(api).await {
Ok(p) => p,
Err(err) => {
panic!("failed to download `model.onnx` or `model.onnx_data`. Check the onnx file exists in the repository. {err}");
}
}
} else {
unreachable!()
};

Ok(model_files)
}

async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
// Single file
tracing::info!("Downloading `model.safetensors`");
@@ -362,6 +394,7 @@ async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
Ok(safetensors_files)
}

#[cfg(feature = "ort")]
async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
let mut model_files: Vec<PathBuf> = Vec::new();

27 changes: 21 additions & 6 deletions core/src/download.rs
@@ -1,6 +1,5 @@
use hf_hub::api::tokio::{ApiError, ApiRepo};
use std::path::PathBuf;
use text_embeddings_backend::download_weights;
use tracing::instrument;

// Old classes used other config names than 'sentence_bert_config.json'
@@ -15,20 +14,36 @@ pub const ST_CONFIG_NAMES: [&str; 7] = [
];

#[instrument(skip_all)]
pub async fn download_artifacts(api: &ApiRepo) -> Result<PathBuf, ApiError> {
pub async fn download_artifacts(api: &ApiRepo, pool_config: bool) -> Result<PathBuf, ApiError> {
let start = std::time::Instant::now();

tracing::info!("Starting download");

// Optionally download the pooling config.
if pool_config {
// If a pooling config exist, download it
let _ = download_pool_config(api).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});
}

// Download legacy sentence transformers config
// We don't warn on failure as it is a legacy file
let _ = download_st_config(api).await;
// Download new sentence transformers config
let _ = download_new_st_config(api).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});

tracing::info!("Downloading `config.json`");
api.get("config.json").await?;

tracing::info!("Downloading `tokenizer.json`");
api.get("tokenizer.json").await?;

let model_files = download_weights(api).await?;
let model_root = model_files[0].parent().unwrap().to_path_buf();
let tokenizer_path = api.get("tokenizer.json").await?;

let model_root = tokenizer_path.parent().unwrap().to_path_buf();
tracing::info!("Model artifacts downloaded in {:?}", start.elapsed());
Ok(model_root)
}
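
Since the backends now download their own weights, `download_artifacts` only fetches the configs and tokenizer and derives the returned model root from the path of `tokenizer.json`; the new `pool_config` flag controls whether the pooling config is fetched as well. A hypothetical caller, assuming the function is reachable as `text_embeddings_core::download::download_artifacts` and using an arbitrary example model id:

use hf_hub::api::tokio::Api;
use hf_hub::{Repo, RepoType};

// Hypothetical end-to-end caller; the model id and crate path are assumptions.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.repo(Repo::new(
        "BAAI/bge-base-en-v1.5".to_string(),
        RepoType::Model,
    ));
    // `true` asks download_artifacts to also fetch the pooling config.
    let model_root = text_embeddings_core::download::download_artifacts(&repo, true).await?;
    println!("model artifacts in {}", model_root.display());
    Ok(())
}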