From 337758d39db91c20a2f699eb6958d5213648164a Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 14 Jan 2025 11:29:59 -0700 Subject: [PATCH 1/2] :bug: use right truncation for non-generative tasks Signed-off-by: Joe Runde --- vllm/config.py | 4 ++++ vllm/transformers_utils/tokenizer_group/__init__.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 59b509d5a961..7a7267dd8dc4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -357,6 +357,10 @@ def __init__(self, supported_tasks, task = self._resolve_task(task, self.hf_config) self.supported_tasks = supported_tasks self.task: Final = task + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" self.pooler_config = self._init_pooler_config(override_pooler_config) self.logits_processor_pattern = logits_processor_pattern diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index d40027679699..09569c564a58 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -24,7 +24,8 @@ def init_tokenizer_from_configs(model_config: ModelConfig, max_input_length=None, tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision) + revision=model_config.tokenizer_revision, + truncation_side=model_config.truncation_side) return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs) From 4ce26730b259063cf5960217ef3773e93bb34e27 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 14 Jan 2025 11:50:50 -0700 Subject: [PATCH 2/2] :test_tube: add simple truncation test Signed-off-by: Joe Runde --- tests/entrypoints/llm/test_encode.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index 41163809237e..3906ad766e0b 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -105,3 +105,10 @@ def test_multiple_pooling_params(llm: LLM): # pooling_params is None, default params should be applied outputs = llm.encode(PROMPTS, pooling_params=None) assert len(PROMPTS) == len(outputs) + + +@pytest.mark.skip_global_cleanup +def test_right_side_truncation(llm: LLM): + # Embeddings models should truncate the end of the prompt + tokenizer = llm.get_tokenizer() + assert tokenizer.truncation_side == "right"