diff --git a/Cargo.lock b/Cargo.lock
index a42181cb..349be87b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3741,6 +3741,15 @@ version = "2.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640"
 
+[[package]]
+name = "simsimd"
+version = "4.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4"
+dependencies = [
+ "cc",
+]
+
 [[package]]
 name = "sketches-ddsketch"
 version = "0.2.2"
@@ -4038,6 +4047,7 @@ dependencies = [
  "reqwest 0.12.5",
  "serde",
  "serde_json",
+ "simsimd",
  "text-embeddings-backend",
  "text-embeddings-core",
  "thiserror",
diff --git a/docs/openapi.json b/docs/openapi.json
index b90189e9..7cc9d50d 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -565,6 +565,93 @@
         }
       }
     },
+    "/similarity": {
+      "post": {
+        "tags": [
+          "Text Embeddings Inference"
+        ],
+        "summary": "Get Sentence Similarity. Returns a 424 status code if the model is not an embedding model.",
+        "operationId": "similarity",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/SimilarityRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Sentence Similarity",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/SimilarityResponse"
+                }
+              }
+            }
+          },
+          "413": {
+            "description": "Batch size error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Batch size error",
+                  "error_type": "validation"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Tokenization error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Tokenization error",
+                  "error_type": "tokenizer"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Embedding Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Inference failed",
+                  "error_type": "backend"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded",
+                  "error_type": "overloaded"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/tokenize": {
       "post": {
         "tags": [
@@ -1441,6 +1528,91 @@
           "$ref": "#/components/schemas/Rank"
         }
       },
+      "SimilarityInput": {
+        "type": "object",
+        "required": [
+          "source_sentence",
+          "sentences"
+        ],
+        "properties": {
+          "sentences": {
+            "type": "array",
+            "items": {
+              "type": "string"
+            },
+            "description": "A list of strings that will be compared against the source_sentence.",
+            "example": [
+              "What is Machine Learning?"
+            ]
+          },
+          "source_sentence": {
+            "type": "string",
+            "description": "The string that you wish to compare the other strings with. This can be a phrase, sentence,\nor longer passage, depending on the model being used.",
+            "example": "What is Deep Learning?"
+          }
+        }
+      },
+      "SimilarityParameters": {
+        "type": "object",
+        "required": [
+          "truncation_direction"
+        ],
+        "properties": {
+          "prompt_name": {
+            "type": "string",
+            "description": "The name of the prompt that should be used for encoding. If not set, no prompt\nwill be applied.\n\nMust be a key in the `Sentence Transformers` configuration `prompts` dictionary.\n\nFor example, if ``prompt_name`` is \"query\" and ``prompts`` is {\"query\": \"query: \", ...},\nthen the sentence \"What is the capital of France?\" will be encoded as\n\"query: What is the capital of France?\" because the prompt text will be prepended to\nany text to encode.",
+            "default": "null",
+            "example": "null",
+            "nullable": true
+          },
+          "truncate": {
+            "type": "boolean",
+            "default": "false",
+            "example": "false",
+            "nullable": true
+          },
+          "truncation_direction": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/TruncationDirection"
+              }
+            ],
+            "default": "right"
+          }
+        }
+      },
+      "SimilarityRequest": {
+        "type": "object",
+        "required": [
+          "inputs"
+        ],
+        "properties": {
+          "inputs": {
+            "$ref": "#/components/schemas/SimilarityInput"
+          },
+          "parameters": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/SimilarityParameters"
+              }
+            ],
+            "default": "null",
+            "nullable": true
+          }
+        }
+      },
+      "SimilarityResponse": {
+        "type": "array",
+        "items": {
+          "type": "number",
+          "format": "float"
+        },
+        "example": [
+          0.0,
+          1.0,
+          0.5
+        ]
+      },
       "SimpleToken": {
         "type": "object",
         "required": [
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 44e56015..2fe52cbc 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -30,6 +30,7 @@ opentelemetry = "0.23.0"
 opentelemetry_sdk = { version = "0.23.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.16.0"
 reqwest = { version = "0.12.5", features = [] }
+simsimd = "4.4.0"
 serde = { workspace = true }
 serde_json = { workspace = true }
 thiserror = { workspace = true }
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index 35f5c1fb..040ac070 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -4,7 +4,8 @@ use crate::http::types::{
     EmbedSparseRequest, EmbedSparseResponse, Embedding, EncodingFormat, Input, InputIds, InputType,
     OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse,
     OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank,
-    RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput,
+    RerankRequest, RerankResponse, Sequence, SimilarityInput, SimilarityParameters,
+    SimilarityRequest, SimilarityResponse, SimpleToken, SparseValue, TokenizeInput,
     TokenizeRequest, TokenizeResponse, TruncationDirection, VertexPrediction, VertexRequest,
     VertexResponse,
 };
@@ -26,6 +27,7 @@ use futures::future::join_all;
 use futures::FutureExt;
 use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
+use simsimd::SpatialSimilarity;
 use std::net::SocketAddr;
 use std::time::{Duration, Instant};
 use text_embeddings_backend::BackendError;
@@ -455,6 +457,88 @@ async fn rerank(
     Ok((headers, Json(response)))
 }
 
+/// Get Sentence Similarity. Returns a 424 status code if the model is not an embedding model.
+#[utoipa::path(
+post,
+tag = "Text Embeddings Inference",
+path = "/similarity",
+request_body = SimilarityRequest,
+responses(
+(status = 200, description = "Sentence Similarity", body = SimilarityResponse),
+(status = 424, description = "Embedding Error", body = ErrorResponse,
+example = json ! ({"error": "Inference failed", "error_type": "backend"})),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
+(status = 422, description = "Tokenization error", body = ErrorResponse,
+example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})),
+(status = 413, description = "Batch size error", body = ErrorResponse,
+example = json ! ({"error": "Batch size error", "error_type": "validation"})),
+)
+)]
+#[instrument(
+    skip_all,
+    fields(total_time, tokenization_time, queue_time, inference_time,)
+)]
+async fn similarity(
+    infer: Extension<Infer>,
+    info: Extension<Info>,
+    Json(req): Json<SimilarityRequest>,
+) -> Result<(HeaderMap, Json<SimilarityResponse>), (StatusCode, Json<ErrorResponse>)> {
+    if req.inputs.sentences.is_empty() {
+        let message = "`inputs.sentences` cannot be empty".to_string();
+        tracing::error!("{message}");
+        let err = ErrorResponse {
+            error: message,
+            error_type: ErrorType::Validation,
+        };
+        let counter = metrics::counter!("te_request_failure", "err" => "validation");
+        counter.increment(1);
+        Err(err)?;
+    }
+    // +1 because of the source sentence
+    let batch_size = req.inputs.sentences.len() + 1;
+    if batch_size > info.max_client_batch_size {
+        let message = format!(
+            "batch size {batch_size} > maximum allowed batch size {}",
+            info.max_client_batch_size
+        );
+        tracing::error!("{message}");
+        let err = ErrorResponse {
+            error: message,
+            error_type: ErrorType::Validation,
+        };
+        let counter = metrics::counter!("te_request_failure", "err" => "batch_size");
+        counter.increment(1);
+        Err(err)?;
+    }
+
+    // Convert the request into a batched embed request, source sentence first
+    let mut inputs = Vec::with_capacity(req.inputs.sentences.len() + 1);
+    inputs.push(InputType::String(req.inputs.source_sentence));
+    for s in req.inputs.sentences {
+        inputs.push(InputType::String(s));
+    }
+    let parameters = req.parameters.unwrap_or_default();
+    let embed_req = EmbedRequest {
+        inputs: Input::Batch(inputs),
+        truncate: parameters.truncate,
+        truncation_direction: parameters.truncation_direction,
+        prompt_name: parameters.prompt_name,
+        normalize: false,
+    };
+
+    // Get embeddings
+    let (header_map, embed_response) = embed(infer, info, Json(embed_req)).await?;
+    let embeddings = embed_response.0 .0;
+
+    // simsimd returns the cosine *distance*; convert it to a cosine similarity
+    let similarities = (1..batch_size)
+        .map(|i| 1.0 - f32::cosine(&embeddings[0], &embeddings[i]).unwrap() as f32)
+        .collect();
+
+    Ok((header_map, Json(SimilarityResponse(similarities))))
+}
+
 /// Get Embeddings. Returns a 424 status code if the model is not an embedding model.
 #[utoipa::path(
 post,
@@ -1472,6 +1556,7 @@ pub async fn run(
 embed_all,
 embed_sparse,
 openai_embed,
+similarity,
 tokenize,
 decode,
 metrics,
@@ -1509,6 +1594,10 @@ pub async fn run(
 TokenizeRequest,
 TokenizeResponse,
 TruncationDirection,
+SimilarityInput,
+SimilarityParameters,
+SimilarityRequest,
+SimilarityResponse,
 SimpleToken,
 InputType,
 InputIds,
@@ -1587,6 +1676,7 @@ pub async fn run(
         .route("/embed_sparse", post(embed_sparse))
         .route("/predict", post(predict))
         .route("/rerank", post(rerank))
+        .route("/similarity", post(similarity))
         .route("/tokenize", post(tokenize))
         .route("/decode", post(decode))
         // OpenAI compat route
@@ -1634,7 +1724,11 @@ pub async fn run(
                 .route("/invocations", post(rerank))
         }
         ModelType::Embedding(model) => {
-            if model.pooling == "splade" {
+            if std::env::var("TASK").ok() == Some("sentence-similarity".to_string()) {
+                app.route("/", post(similarity))
+                    // AWS Sagemaker route
+                    .route("/invocations", post(similarity))
+            } else if model.pooling == "splade" {
                 app.route("/", post(embed_sparse))
                     // AWS Sagemaker route
                     .route("/invocations", post(embed_sparse))
diff --git a/router/src/http/types.rs b/router/src/http/types.rs
index 4414ecb4..486f8cb7 100644
--- a/router/src/http/types.rs
+++ b/router/src/http/types.rs
@@ -360,6 +360,48 @@ pub(crate) struct OpenAICompatResponse {
     pub usage: OpenAICompatUsage,
 }
 
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct SimilarityInput {
+    /// The string that you wish to compare the other strings with. This can be a phrase, sentence,
+    /// or longer passage, depending on the model being used.
+    #[schema(example = "What is Deep Learning?")]
+    pub source_sentence: String,
+    /// A list of strings that will be compared against the source_sentence.
+    #[schema(example = json!(["What is Machine Learning?"]))]
+    pub sentences: Vec<String>,
+}
+
+#[derive(Deserialize, ToSchema, Default)]
+pub(crate) struct SimilarityParameters {
+    #[schema(default = "false", example = "false", nullable = true)]
+    pub truncate: Option<bool>,
+    #[schema(default = "right", example = "right")]
+    pub truncation_direction: TruncationDirection,
+    /// The name of the prompt that should be used for encoding. If not set, no prompt
+    /// will be applied.
+    ///
+    /// Must be a key in the `Sentence Transformers` configuration `prompts` dictionary.
+    ///
+    /// For example, if ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ", ...},
+    /// then the sentence "What is the capital of France?" will be encoded as
+    /// "query: What is the capital of France?" because the prompt text will be prepended to
+    /// any text to encode.
+    #[schema(default = "null", example = "null", nullable = true)]
+    pub prompt_name: Option<String>,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub(crate) struct SimilarityRequest {
+    pub inputs: SimilarityInput,
+    /// Additional inference parameters for Sentence Similarity
+    #[schema(default = "null", example = "null", nullable = true)]
+    pub parameters: Option<SimilarityParameters>,
+}
+
+#[derive(Serialize, ToSchema)]
+#[schema(example = json!([0.0, 1.0, 0.5]))]
+pub(crate) struct SimilarityResponse(pub Vec<f32>);
+
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct EmbedRequest {
     pub inputs: Input,
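
One subtlety worth spelling out: `simsimd`'s `SpatialSimilarity::cosine` returns the cosine *distance* as an `Option<f64>` (`None` when the two slices differ in length), which is why the handler maps each score through `1.0 - …`. A minimal standalone sketch of that conversion, assuming simsimd 4.x semantics:

```rust
// Standalone sketch (not part of the diff) of the simsimd call used in the
// `similarity` handler above. Assumes simsimd 4.x, where `f32::cosine(a, b)`
// comes from the `SpatialSimilarity` trait and returns the cosine *distance*
// as an Option<f64>, None when the slice lengths differ.
use simsimd::SpatialSimilarity;

fn main() {
    let source = [1.0f32, 0.0, 0.0];
    let parallel = [2.0f32, 0.0, 0.0]; // same direction -> similarity ~1.0
    let orthogonal = [0.0f32, 1.0, 0.0]; // right angle -> similarity ~0.0

    for other in [&parallel[..], &orthogonal[..]] {
        let distance = f32::cosine(&source, other).expect("equal-length vectors");
        // The same conversion the handler applies per sentence.
        let similarity = 1.0 - distance as f32;
        println!("distance = {distance:.3}, similarity = {similarity:.3}");
    }
}
```

Note that the handler calls `embed` with `normalize: false`; cosine similarity is scale-invariant, so the unnormalized embeddings still yield well-defined scores.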
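
For an end-to-end check of the new route, a client snippet along these lines should work; the host, port, and sentences are illustrative, and `parameters` is omitted so `SimilarityParameters::default()` applies:

```rust
// Hypothetical smoke test for the new POST /similarity route (not part of
// the diff). Assumes a router serving an embedding model on localhost:8080,
// with reqwest (and its "json" feature), tokio, and serde_json available.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Shape mirrors SimilarityRequest; `parameters` is optional and omitted.
    let request = json!({
        "inputs": {
            "source_sentence": "What is Deep Learning?",
            "sentences": [
                "What is Machine Learning?",
                "The weather is nice today."
            ]
        }
    });

    // SimilarityResponse serializes as a bare JSON array of floats:
    // one cosine similarity per entry in `inputs.sentences`.
    let scores: Vec<f32> = reqwest::Client::new()
        .post("http://localhost:8080/similarity")
        .json(&request)
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;

    println!("similarities: {scores:?}");
    Ok(())
}
```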