Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions router/src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ async fn predict(
))
};

let truncate = req.truncate.unwrap_or(info.auto_truncate);

let (response, metadata) = match req.inputs {
PredictInput::Single(inputs) => {
metrics::increment_counter!("te_request_count", "method" => "single");
Expand All @@ -159,7 +161,7 @@ async fn predict(
let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let (prompt_tokens, tokenization, queue, inference, predictions) = predict_inner(
inputs,
req.truncate.unwrap_or(info.auto_truncate),
truncate,
req.raw_scores,
infer.0,
info.0,
Expand Down Expand Up @@ -208,7 +210,7 @@ async fn predict(
let local_info = info.clone();
futures.push(predict_inner(
input,
req.truncate.unwrap_or(info.auto_truncate),
truncate,
req.raw_scores,
local_infer.0,
local_info.0,
Expand Down Expand Up @@ -342,6 +344,8 @@ async fn rerank(
))
};

let truncate = req.truncate.unwrap_or(info.auto_truncate);

let (response, metadata) = {
metrics::increment_counter!("te_request_count", "method" => "batch");

Expand Down Expand Up @@ -370,7 +374,7 @@ async fn rerank(
futures.push(rerank_inner(
req.query.clone(),
text.clone(),
req.truncate.unwrap_or(info.auto_truncate),
truncate,
req.raw_scores,
local_infer.0,
))
Expand Down Expand Up @@ -470,6 +474,8 @@ async fn embed(
let span = tracing::Span::current();
let start_time = Instant::now();

let truncate = req.truncate.unwrap_or(info.auto_truncate);

let (response, metadata) = match req.inputs {
Input::Single(input) => {
metrics::increment_counter!("te_request_count", "method" => "single");
Expand All @@ -478,12 +484,7 @@ async fn embed(

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
.embed_pooled(
input,
req.truncate.unwrap_or(info.auto_truncate),
req.normalize,
permit,
)
.embed_pooled(input, truncate, req.normalize, permit)
.await
.map_err(ErrorResponse::from)?;

Expand Down Expand Up @@ -536,7 +537,6 @@ async fn embed(
for input in inputs {
compute_chars += input.count_chars();

let truncate = req.truncate.unwrap_or(info.auto_truncate);
let local_infer = infer.clone();
futures.push(async move {
let permit = local_infer.acquire_permit().await;
Expand Down Expand Up @@ -631,6 +631,7 @@ async fn embed_sparse(
}
sparse_values
};
let truncate = req.truncate.unwrap_or(info.auto_truncate);

let (response, metadata) = match req.inputs {
Input::Single(input) => {
Expand All @@ -640,7 +641,7 @@ async fn embed_sparse(

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
.embed_sparse(input, req.truncate.unwrap_or(info.auto_truncate), permit)
.embed_sparse(input, truncate, permit)
.await
.map_err(ErrorResponse::from)?;

Expand Down Expand Up @@ -693,7 +694,6 @@ async fn embed_sparse(
for input in inputs {
compute_chars += input.count_chars();

let truncate = req.truncate.unwrap_or(info.auto_truncate);
let local_infer = infer.clone();
futures.push(async move {
let permit = local_infer.acquire_permit().await;
Expand Down Expand Up @@ -779,6 +779,8 @@ async fn embed_all(
let span = tracing::Span::current();
let start_time = Instant::now();

let truncate = req.truncate.unwrap_or(info.auto_truncate);

let (response, metadata) = match req.inputs {
Input::Single(input) => {
metrics::increment_counter!("te_request_count", "method" => "single");
Expand All @@ -787,7 +789,7 @@ async fn embed_all(

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
.embed_all(input, req.truncate.unwrap_or(info.auto_truncate), permit)
.embed_all(input, truncate, permit)
.await
.map_err(ErrorResponse::from)?;

Expand Down Expand Up @@ -840,7 +842,6 @@ async fn embed_all(
for input in inputs {
compute_chars += input.count_chars();

let truncate = req.truncate.unwrap_or(info.auto_truncate);
let local_infer = infer.clone();
futures.push(async move {
let permit = local_infer.acquire_permit().await;
Expand Down Expand Up @@ -925,6 +926,8 @@ async fn openai_embed(
let span = tracing::Span::current();
let start_time = Instant::now();

let truncate = info.auto_truncate;

let (embeddings, metadata) = match req.input {
Input::Single(input) => {
metrics::increment_counter!("te_request_count", "method" => "single");
Expand All @@ -933,7 +936,7 @@ async fn openai_embed(

let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
let response = infer
.embed_pooled(input, false, true, permit)
.embed_pooled(input, truncate, true, permit)
.await
.map_err(ErrorResponse::from)?;

Expand Down Expand Up @@ -993,7 +996,9 @@ async fn openai_embed(
let local_infer = infer.clone();
futures.push(async move {
let permit = local_infer.acquire_permit().await;
local_infer.embed_pooled(input, false, true, permit).await
local_infer
.embed_pooled(input, truncate, true, permit)
.await
})
}
let results = join_all(futures)
Expand Down