From 7ed87cb8cd3ceb4a8713b8bf9fac189fe289da7f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 9 Apr 2025 12:35:18 +0200
Subject: [PATCH 1/4] Fixing the tokenization routes token (offsets are in
 bytes, not in char with the regular encode method).

---
 core/src/tokenization.rs  | 35 +++++++++++++++++++++++++++++++++++
 router/src/http/server.rs |  3 ++-
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index 71617daf..25abbc88 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -485,3 +485,38 @@ enum TokenizerRequest {
         Span,
     ),
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hf_hub::api::sync::ApiBuilder;
+
+    #[test]
+    fn tokenizer() {
+        let api = ApiBuilder::from_env().build().unwrap();
+        let filename = api
+            .model("BAAI/bge-m3".to_string())
+            .get("tokenizer.json")
+            .unwrap();
+        let string = "这是一个文本向量化的测试句子";
+        let tokenizer = Tokenizer::from_file(filename).unwrap();
+
+        let encoded = tokenizer.encode(string, true).unwrap();
+        assert_eq!(
+            encoded.get_offsets(),
+            vec![
+                (0, 0),
+                (0, 3),
+                (0, 12),
+                (12, 18),
+                (18, 21),
+                (21, 24),
+                (24, 30),
+                (30, 36),
+                (36, 39),
+                (39, 42),
+                (0, 0)
+            ]
+        );
+    }
+}
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index cadb6c18..902333f3 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -1312,7 +1312,8 @@ async fn tokenize(
                     stop: None,
                 },
                 false => {
-                    let text: String = input.chars().skip(start).take(stop - start).collect();
+                    let text: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
+                    let text: String = String::from_utf8_lossy(&text).to_string();
                     SimpleToken {
                         id,
                         text,
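Note on the bug the patch above fixes (not part of the series): `tokenizers` reports offsets in bytes, while the old route sliced the input with `chars()`, so any multi-byte text returned the wrong span. A minimal standalone sketch, reusing the test string and the 12..18 offsets from the patch:

    fn main() {
        let input = "这是一个文本向量化的测试句子";
        // The tokenizer reports byte offsets: "文本" spans bytes 12..18 (3 bytes per char here).
        let (start, stop) = (12usize, 18usize);

        // Old, buggy route: treats the byte offsets as character indices.
        let wrong: String = input.chars().skip(start).take(stop - start).collect();
        // Fixed route: slice by bytes, then decode lossily in case an offset
        // ever lands inside a code point.
        let bytes: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
        let right = String::from_utf8_lossy(&bytes).to_string();

        assert_eq!(right, "文本");
        assert_eq!(wrong, "句子"); // characters 12..14 of the input, not bytes 12..18
    }

From 63d9fde31deaab0e1e5b0f495bfbd606c893aa04 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 9 Apr 2025 15:15:52 +0200
Subject: [PATCH 2/4] Fixing the tokenization in both routes.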
---
 core/src/tokenization.rs  | 127 ++++++++++++++++++++++++++++++++++++++
 router/src/grpc/server.rs |  46 ++++++--------
 router/src/http/server.rs |  44 +++++--------
 3 files changed, 163 insertions(+), 54 deletions(-)

diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index 25abbc88..7636afa8 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -16,6 +16,16 @@ pub struct Tokenization {
     sender: async_channel::Sender<TokenizerRequest>,
 }

+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct SimpleToken {
+    pub id: u32,
+    pub text: String,
+    pub special: bool,
+    pub start: Option<usize>,
+    pub stop: Option<usize>,
+}
+
 impl Tokenization {
     pub fn new(
         workers: usize,
@@ -486,6 +496,39 @@ enum TokenizerRequest {
     ),
 }

+pub fn into_tokens(encoding: tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
+    encoding
+        .get_ids()
+        .iter()
+        .zip(encoding.get_offsets())
+        .zip(encoding.get_special_tokens_mask())
+        .zip(encoding.get_tokens())
+        .map(|(((&id, &(start, stop)), special), token)| {
+            let special = *special == 1;
+            match special {
+                true => SimpleToken {
+                    id,
+                    text: token.clone(),
+                    special,
+                    start: None,
+                    stop: None,
+                },
+                false => {
+                    let text: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
+                    let text: String = String::from_utf8_lossy(&text).to_string();
+                    SimpleToken {
+                        id,
+                        text,
+                        special,
+                        start: Some(start),
+                        stop: Some(stop),
+                    }
+                }
+            }
+        })
+        .collect()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -518,5 +561,89 @@ mod tests {
                 (0, 0)
             ]
         );
+
+        let tokens = into_tokens(encoded, &string);
+        assert_eq!(
+            tokens,
+            vec![
+                SimpleToken {
+                    id: 0,
+                    text: "<s>".to_string(),
+                    special: true,
+                    start: None,
+                    stop: None
+                },
+                SimpleToken {
+                    id: 6,
+                    text: "这".to_string(),
+                    special: false,
+                    start: Some(0),
+                    stop: Some(3)
+                },
+                SimpleToken {
+                    id: 100013,
+                    text: "这是一个".to_string(),
+                    special: false,
+                    start: Some(0),
+                    stop: Some(12)
+                },
+                SimpleToken {
+                    id: 189061,
+                    text: "文本".to_string(),
+                    special: false,
+                    start: Some(12),
+                    stop: Some(18)
+                },
+                SimpleToken {
+                    id: 2110,
+                    text: "向".to_string(),
+                    special: false,
+                    start: Some(18),
+                    stop: Some(21)
+                },
+                SimpleToken {
+                    id: 3272,
+                    text: "量".to_string(),
+                    special: false,
+                    start: Some(21),
+                    stop: Some(24)
+                },
+                SimpleToken {
+                    id: 41904,
+                    text: "化的".to_string(),
+                    special: false,
+                    start: Some(24),
+                    stop: Some(30)
+                },
+                SimpleToken {
+                    id: 49125,
+                    text: "测试".to_string(),
+                    special: false,
+                    start: Some(30),
+                    stop: Some(36)
+                },
+                SimpleToken {
+                    id: 27683,
+                    text: "句".to_string(),
+                    special: false,
+                    start: Some(36),
+                    stop: Some(39)
+                },
+                SimpleToken {
+                    id: 1344,
+                    text: "子".to_string(),
+                    special: false,
+                    start: Some(39),
+                    stop: Some(42)
+                },
+                SimpleToken {
+                    id: 2,
+                    text: "</s>".to_string(),
+                    special: true,
+                    start: None,
+                    stop: None
+                }
+            ]
+        );
     }
 }
diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs
index 8de706dd..f0666aa5 100644
--- a/router/src/grpc/server.rs
+++ b/router/src/grpc/server.rs
@@ -15,7 +15,9 @@ use std::future::Future;
 use std::net::SocketAddr;
 use std::time::{Duration, Instant};
 use text_embeddings_core::infer::Infer;
-use text_embeddings_core::tokenization::EncodingInput;
+use text_embeddings_core::tokenization::{
+    into_tokens, EncodingInput, SimpleToken as CoreSimpleToken,
+};
 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -340,32 +342,22 @@ impl TextEmbeddingsService {
             .map_err(ErrorResponse::from)?;

         let inputs = encoded_inputs.unwrap_or(inputs);
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .zip(encoding.get_special_tokens_mask())
-            .zip(encoding.get_tokens())
-            .map(|(((&id, &(start, stop)), special), token)| {
-                let special = *special == 1;
-                match special {
-                    true => SimpleToken {
-                        id,
-                        text: token.clone(),
-                        special,
-                        start: None,
-                        stop: None,
-                    },
-                    false => {
-                        let text: String = inputs.chars().skip(start).take(stop - start).collect();
-                        SimpleToken {
-                            id,
-                            text,
-                            special,
-                            start: Some(start as u32),
-                            stop: Some(stop as u32),
-                        }
-                    }
+        let tokens: Vec<SimpleToken> = into_tokens(encoding, &input)
+            .into_iter()
+            .map(|t| {
+                let CoreSimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
+                } = t;
+                SimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
                 }
             })
             .collect();
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index 902333f3..2abd81e9 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -34,6 +34,7 @@ use text_embeddings_backend::BackendError;
 use text_embeddings_core::infer::{
     AllEmbeddingsInferResponse, Infer, InferMetadata, PooledEmbeddingsInferResponse,
 };
+use text_embeddings_core::tokenization::{into_tokens, SimpleToken as CoreSimpleToken};
 use text_embeddings_core::TextEmbeddingsError;
 use tokio::sync::OwnedSemaphorePermit;
 use tower_http::cors::{AllowOrigin, CorsLayer};
@@ -1295,33 +1296,22 @@ async fn tokenize(
             .map_err(ErrorResponse::from)?;

         let input = encoded_input.unwrap_or(input);
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .zip(encoding.get_special_tokens_mask())
-            .zip(encoding.get_tokens())
-            .map(|(((&id, &(start, stop)), special), token)| {
-                let special = *special == 1;
-                match special {
-                    true => SimpleToken {
-                        id,
-                        text: token.clone(),
-                        special,
-                        start: None,
-                        stop: None,
-                    },
-                    false => {
-                        let text: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
-                        let text: String = String::from_utf8_lossy(&text).to_string();
-                        SimpleToken {
-                            id,
-                            text,
-                            special,
-                            start: Some(start),
-                            stop: Some(stop),
-                        }
-                    }
+        let tokens: Vec<SimpleToken> = into_tokens(encoding, &input)
+            .into_iter()
+            .map(|t| {
+                let CoreSimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
+                } = t;
+                SimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
                 }
             })
             .collect();
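The `into_tokens` helper introduced above is now the single place where offsets are decoded. A hedged usage sketch (the local `tokenizer.json` path is hypothetical; the unit test in the patch fetches the same file from the Hub via hf-hub) showing the invariant the refactor guarantees: non-special token offsets index the input by bytes, so slicing round-trips the token text:

    use text_embeddings_core::tokenization::into_tokens;
    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Hypothetical local path; the unit test above downloads the same
        // file from the Hub instead.
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;
        let input = "这是一个文本向量化的测试句子";
        let encoding = tokenizer.encode(input, true)?;

        for token in into_tokens(encoding, input) {
            // Special tokens carry no offsets; everything else indexes
            // straight into the original input as bytes.
            if let (Some(start), Some(stop)) = (token.start, token.stop) {
                assert_eq!(&input.as_bytes()[start..stop], token.text.as_bytes());
            }
        }
        Ok(())
    }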
From 5a1149d52601adbc1888f6125e40347e5ed4e5dd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 9 Apr 2025 15:31:17 +0200
Subject: [PATCH 3/4] Typo.

---
 router/src/grpc/server.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs
index f0666aa5..09db37f3 100644
--- a/router/src/grpc/server.rs
+++ b/router/src/grpc/server.rs
@@ -342,7 +342,7 @@ impl TextEmbeddingsService {
             .map_err(ErrorResponse::from)?;

         let inputs = encoded_inputs.unwrap_or(inputs);
-        let tokens: Vec<SimpleToken> = into_tokens(encoding, &input)
+        let tokens: Vec<SimpleToken> = into_tokens(encoding, &inputs)
             .into_iter()
             .map(|t| {
                 let CoreSimpleToken {

From 46dd4ec1ae0f0af9ae5fe6cde55e54e33e23857d Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 9 Apr 2025 16:21:55 +0200
Subject: [PATCH 4/4] Casting.

---
 router/src/grpc/server.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs
index 09db37f3..3c98f8b8 100644
--- a/router/src/grpc/server.rs
+++ b/router/src/grpc/server.rs
@@ -356,8 +356,8 @@ impl TextEmbeddingsService {
                     id,
                     text,
                     special,
-                    start,
-                    stop,
+                    start: start.map(|s| s as u32),
+                    stop: stop.map(|s| s as u32),
                 }
             })
             .collect();
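The final patch reconciles the core type's `Option<usize>` offsets with the gRPC message's `u32` fields via `Option::map`. A self-contained sketch of just that conversion; the struct definitions here are stand-ins, not the real generated proto types:

    // Stand-in types: the real code destructures the core SimpleToken and
    // builds the gRPC response message, whose offset fields are u32.
    struct CoreToken {
        start: Option<usize>,
        stop: Option<usize>,
    }

    struct GrpcToken {
        start: Option<u32>,
        stop: Option<u32>,
    }

    fn convert(t: CoreToken) -> GrpcToken {
        GrpcToken {
            // `as u32` truncates silently on overflow; byte offsets of any
            // realistic request fit comfortably below u32::MAX.
            start: t.start.map(|s| s as u32),
            stop: t.stop.map(|s| s as u32),
        }
    }

    fn main() {
        let g = convert(CoreToken { start: Some(12), stop: Some(18) });
        assert_eq!((g.start, g.stop), (Some(12), Some(18)));
    }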