From 95899ddfb3d8d524ff664e04efbbed2c30f101be Mon Sep 17 00:00:00 2001 From: Andy W Date: Thu, 19 Sep 2024 23:35:15 -0400 Subject: [PATCH 01/15] re-formatted --- vllm/entrypoints/llm.py | 101 ++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 248b070611cd..25ddcc6c0803 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -351,7 +351,8 @@ def generate( def chat( self, - messages: List[ChatCompletionMessageParam], + conversations: Union[List[ChatCompletionMessageParam], + List[List[ChatCompletionMessageParam]]], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, use_tqdm: bool = True, @@ -359,7 +360,7 @@ def chat( chat_template: Optional[str] = None, add_generation_prompt: bool = True, tools: Optional[List[Dict[str, Any]]] = None, - ) -> List[RequestOutput]: + ) -> Union[List[List[RequestOutput]], List[RequestOutput]]: """ Generate responses for a chat conversation. @@ -371,8 +372,9 @@ def chat( to the OpenAI API. Args: - messages: A single conversation represented as a list of messages. - Each message is a dictionary with 'role' and 'content' keys. + conversations: A list or a single conversation represented as a list + of messages. Each message is a dictionary with 'role' and + 'content' keys. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -386,49 +388,66 @@ def chat( to each message. Returns: - A list of ``RequestOutput`` objects containing the generated - responses in the same order as the input messages. + A list of lists or single list of ``RequestOutput`` objects + containing the generated responses in the same order as the input + conversations and messages. """ + list_of_conversations: List[List[ChatCompletionMessageParam]] - tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() - - conversation, mm_data = parse_chat_messages(messages, model_config, - tokenizer) - - prompt: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): - prompt = apply_mistral_chat_template( - tokenizer, - messages=messages, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # Handle multi and single conversations + if is_list_of(conversations, list): + # conversations is List[List[...]] + list_of_conversations = conversations else: - prompt = apply_hf_chat_template( - tokenizer, - conversation=conversation, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # conversations is List[...] 
+ list_of_conversations = [conversations] + + outputs = [] + + for messages in list_of_conversations: + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + + conversation, mm_data = parse_chat_messages( + messages, model_config, tokenizer) + + prompt: Union[str, List[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + else: + prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + + inputs: PromptInputs + if is_list_of(prompt, int): + inputs = TokensPrompt(prompt_token_ids=prompt) + else: + inputs = TextPrompt(prompt=prompt) - inputs: PromptInputs - if is_list_of(prompt, int): - inputs = TokensPrompt(prompt_token_ids=prompt) - else: - inputs = TextPrompt(prompt=prompt) + if mm_data is not None: + inputs["multi_modal_data"] = mm_data - if mm_data is not None: - inputs["multi_modal_data"] = mm_data + out = self.generate( + inputs, + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + outputs.append(out) - return self.generate( - inputs, - sampling_params=sampling_params, - use_tqdm=use_tqdm, - lora_request=lora_request, - ) + # When conversations is List[...], return a single list. + return outputs if len(outputs) > 1 else outputs[0] @overload # LEGACY: single (prompt + optional token ids) def encode( From f7ed4a3b5977f5f0c5d7c6198fd356c33e5ef0af Mon Sep 17 00:00:00 2001 From: Andy <37781802+aandyw@users.noreply.github.com> Date: Thu, 19 Sep 2024 23:57:50 -0400 Subject: [PATCH 02/15] Update vllm/entrypoints/llm.py Co-authored-by: Cyrus Leung --- vllm/entrypoints/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 25ddcc6c0803..1e0661666786 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -372,9 +372,9 @@ def chat( to the OpenAI API. Args: - conversations: A list or a single conversation represented as a list - of messages. Each message is a dictionary with 'role' and - 'content' keys. + conversations: A list of conversations or a single conversation. + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. 
When it From 403bfdd0bb5dc177d6303da19c8148c8f0d89983 Mon Sep 17 00:00:00 2001 From: Andy W Date: Fri, 20 Sep 2024 00:56:31 -0400 Subject: [PATCH 03/15] added overloads --- vllm/entrypoints/llm.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1e0661666786..8d79c60417dc 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -349,6 +349,34 @@ def generate( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) + @overload + def chat( + self, + conversations: List[ChatCompletionMessageParam], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + add_generation_prompt: bool = True, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> List[RequestOutput]: + ... + + @overload + def chat( + self, + conversations: List[List[ChatCompletionMessageParam]], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + add_generation_prompt: bool = True, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> List[List[RequestOutput]]: + ... + def chat( self, conversations: Union[List[ChatCompletionMessageParam], @@ -360,7 +388,7 @@ def chat( chat_template: Optional[str] = None, add_generation_prompt: bool = True, tools: Optional[List[Dict[str, Any]]] = None, - ) -> Union[List[List[RequestOutput]], List[RequestOutput]]: + ) -> Union[List[RequestOutput], List[List[RequestOutput]]]: """ Generate responses for a chat conversation. @@ -446,7 +474,7 @@ def chat( ) outputs.append(out) - # When conversations is List[...], return a single list. + # When conversations is List[...], return a single list return outputs if len(outputs) > 1 else outputs[0] @overload # LEGACY: single (prompt + optional token ids) From 28eba3554862d939cbc4f85c0469f33c21eb13b0 Mon Sep 17 00:00:00 2001 From: Andy W Date: Fri, 20 Sep 2024 20:51:30 -0400 Subject: [PATCH 04/15] changed conversations -> messages --- vllm/entrypoints/llm.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8d79c60417dc..0185eb08d92b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -352,7 +352,7 @@ def generate( @overload def chat( self, - conversations: List[ChatCompletionMessageParam], + messages: List[ChatCompletionMessageParam], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, use_tqdm: bool = True, @@ -366,7 +366,7 @@ def chat( @overload def chat( self, - conversations: List[List[ChatCompletionMessageParam]], + messages: List[List[ChatCompletionMessageParam]], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, use_tqdm: bool = True, @@ -379,8 +379,8 @@ def chat( def chat( self, - conversations: Union[List[ChatCompletionMessageParam], - List[List[ChatCompletionMessageParam]]], + messages: Union[List[ChatCompletionMessageParam], + List[List[ChatCompletionMessageParam]]], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, use_tqdm: bool = True, @@ -400,7 +400,7 @@ def chat( to the OpenAI API. Args: - conversations: A list of conversations or a single conversation. 
+ messages: A list of conversations or a single conversation. - Each conversation is represented as a list of messages. - Each message is a dictionary with 'role' and 'content' keys. sampling_params: The sampling parameters for text generation. @@ -420,30 +420,30 @@ def chat( containing the generated responses in the same order as the input conversations and messages. """ - list_of_conversations: List[List[ChatCompletionMessageParam]] + list_of_messages: List[List[ChatCompletionMessageParam]] # Handle multi and single conversations - if is_list_of(conversations, list): + if is_list_of(messages, list): # conversations is List[List[...]] - list_of_conversations = conversations + list_of_messages = messages else: # conversations is List[...] - list_of_conversations = [conversations] + list_of_messages = [messages] outputs = [] - for messages in list_of_conversations: + for msgs in list_of_messages: tokenizer = self.get_tokenizer() model_config = self.llm_engine.get_model_config() conversation, mm_data = parse_chat_messages( - messages, model_config, tokenizer) + msgs, model_config, tokenizer) prompt: Union[str, List[int]] if isinstance(tokenizer, MistralTokenizer): prompt = apply_mistral_chat_template( tokenizer, - messages=messages, + messages=msgs, chat_template=chat_template, add_generation_prompt=add_generation_prompt, tools=tools, From 148c4ca12729d1f3c3c34ce6e2a9af013c85ee83 Mon Sep 17 00:00:00 2001 From: Andy W Date: Fri, 20 Sep 2024 21:26:47 -0400 Subject: [PATCH 05/15] added multi chat unittest --- tests/entrypoints/llm/test_generate.py | 34 ++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index ef34bebbb0f8..d82334acb590 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -162,6 +162,40 @@ def test_chat(): assert len(outputs) == 1 +def test_multi_chat(): + + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + prompt2 = "Explain what among us is." + + messages = [ + [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ], + [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ] + ] + + outputs = llm.chat(messages) + assert len(outputs) == 2 + + @pytest.mark.parametrize("image_urls", [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) def test_chat_multi_image(image_urls: List[str]): From 71b623177ffdaad6e9f7cdc471d49828a2b1cb32 Mon Sep 17 00:00:00 2001 From: Andy W Date: Fri, 20 Sep 2024 23:41:56 -0400 Subject: [PATCH 06/15] update --- tests/entrypoints/llm/test_generate.py | 43 +++++++++++++------------- vllm/entrypoints/llm.py | 6 ++-- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index d82334acb590..cd989225e248 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -169,29 +169,30 @@ def test_multi_chat(): prompt1 = "Explain the concept of entropy." prompt2 = "Explain what among us is." 
- messages = [ - [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, - ], - [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, - ] + conversation1 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + + conversation2 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, ] + messages = [conversation1, conversation2] + outputs = llm.chat(messages) assert len(outputs) == 2 diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0185eb08d92b..2badbda00c53 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -424,10 +424,10 @@ def chat( # Handle multi and single conversations if is_list_of(messages, list): - # conversations is List[List[...]] + # messages is List[List[...]] list_of_messages = messages else: - # conversations is List[...] + # messages is List[...] list_of_messages = [messages] outputs = [] @@ -474,7 +474,7 @@ def chat( ) outputs.append(out) - # When conversations is List[...], return a single list + # When messages is List[...], return a single list return outputs if len(outputs) > 1 else outputs[0] @overload # LEGACY: single (prompt + optional token ids) From ec011b41c3d7349c3604d7a2dc69b6c3f69ceda2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 11:51:05 +0800 Subject: [PATCH 07/15] Fix outdated name --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d9d4952aa0ed..fbeb781d3b0f 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -460,7 +460,7 @@ def chat( tools=tools, ) - prompt: PromptInputs + prompt: PromptType if is_list_of(prompt_data, int): prompt = TokensPrompt(prompt_token_ids=prompt_data) else: From bcada94e587bf04991febc3b23fb615882305be2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 11:51:59 +0800 Subject: [PATCH 08/15] Add type annotation --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fbeb781d3b0f..186173a8be08 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -433,7 +433,7 @@ def chat( # messages is List[...] list_of_messages = [messages] - outputs = [] + outputs: List[List[RequestOutput]]] = [] for msgs in list_of_messages: tokenizer = self.get_tokenizer() From ba593d5201df7e5b7ba962e95406ef048fdb2ab6 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 11:55:41 +0800 Subject: [PATCH 09/15] Fix --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 186173a8be08..89d35d5d028e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -433,7 +433,7 @@ def chat( # messages is List[...] 
list_of_messages = [messages] - outputs: List[List[RequestOutput]]] = [] + outputs: List[List[RequestOutput]] = [] for msgs in list_of_messages: tokenizer = self.get_tokenizer() From 8a88e108189fe4ae852e7b73da78dd65b203ad9c Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 23 Sep 2024 09:50:38 -0700 Subject: [PATCH 10/15] update example --- examples/offline_inference_chat.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index c2020724c72f..1d5451bab161 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -39,6 +39,31 @@ def print_outputs(outputs): use_tqdm=False) print_outputs(outputs) +# You can run batch inference with llm.chat API +conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] +conversations = [conversation for _ in range(10)] +outputs = llm.chat(messages=conversations, + sampling_params=sampling_params, + use_tqdm=False) +print_outputs(outputs) + # A chat template can be optionally supplied. # If not, the model will use its default chat template. From e87415b9ca95045267dc9412951a23d564363741 Mon Sep 17 00:00:00 2001 From: aandyw <37781802+aandyw@users.noreply.github.com> Date: Tue, 24 Sep 2024 01:00:30 -0400 Subject: [PATCH 11/15] added batch generate --- vllm/entrypoints/llm.py | 88 +++++++++++++++++++++-------------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 31506d63cce3..1cb0c21fe71a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -354,34 +354,6 @@ def generate( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) - @overload - def chat( - self, - messages: List[ChatCompletionMessageParam], - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, - use_tqdm: bool = True, - lora_request: Optional[LoRARequest] = None, - chat_template: Optional[str] = None, - add_generation_prompt: bool = True, - tools: Optional[List[Dict[str, Any]]] = None, - ) -> List[RequestOutput]: - ... - - @overload - def chat( - self, - messages: List[List[ChatCompletionMessageParam]], - sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, - use_tqdm: bool = True, - lora_request: Optional[LoRARequest] = None, - chat_template: Optional[str] = None, - add_generation_prompt: bool = True, - tools: Optional[List[Dict[str, Any]]] = None, - ) -> List[List[RequestOutput]]: - ... - def chat( self, messages: Union[List[ChatCompletionMessageParam], @@ -393,7 +365,7 @@ def chat( chat_template: Optional[str] = None, add_generation_prompt: bool = True, tools: Optional[List[Dict[str, Any]]] = None, - ) -> Union[List[RequestOutput], List[List[RequestOutput]]]: + ) -> List[RequestOutput]: """ Generate responses for a chat conversation. @@ -421,9 +393,8 @@ def chat( to each message. Returns: - A list of lists or single list of ``RequestOutput`` objects - containing the generated responses in the same order as the input - conversations and messages. + A list of ``RequestOutput`` objects containing the generated + responses in the same order as the input messages. 
""" list_of_messages: List[List[ChatCompletionMessageParam]] @@ -435,7 +406,7 @@ def chat( # messages is List[...] list_of_messages = [messages] - outputs: List[List[RequestOutput]] = [] + prompts: List[List[PromptType]] = [] for msgs in list_of_messages: tokenizer = self.get_tokenizer() @@ -471,16 +442,16 @@ def chat( if mm_data is not None: prompt["multi_modal_data"] = mm_data - out = self.generate( - prompt, - sampling_params=sampling_params, - use_tqdm=use_tqdm, - lora_request=lora_request, - ) - outputs.append(out) + prompts.append(prompt) + + outputs = self.generate( + prompts, + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) - # When messages is List[...], return a single list - return outputs if len(outputs) > 1 else outputs[0] + return outputs @overload # LEGACY: single (prompt + optional token ids) def encode( @@ -800,3 +771,36 @@ def _is_encoder_decoder_model(self): def _is_embedding_model(self): return self.llm_engine.is_embedding_model() + + +if __name__ == "__main__": + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + prompt2 = "Explain what among us is." + + conversation1 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + + conversation2 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ] + + messages = [conversation1, conversation2] + + outputs = llm.chat(messages) \ No newline at end of file From b7ab2948b50f04a968a6902d3d5a4b7ea7fdddb6 Mon Sep 17 00:00:00 2001 From: Andy <37781802+aandyw@users.noreply.github.com> Date: Tue, 24 Sep 2024 01:34:58 -0400 Subject: [PATCH 12/15] Update llm.py --- vllm/entrypoints/llm.py | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1cb0c21fe71a..e2ed268b8a47 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -771,36 +771,3 @@ def _is_encoder_decoder_model(self): def _is_embedding_model(self): return self.llm_engine.is_embedding_model() - - -if __name__ == "__main__": - llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") - - prompt1 = "Explain the concept of entropy." - prompt2 = "Explain what among us is." - - conversation1 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, - ] - - conversation2 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, - ] - - messages = [conversation1, conversation2] - - outputs = llm.chat(messages) \ No newline at end of file From bdcf223a7cc645ee428552ff190fe103cb40335a Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 23 Sep 2024 22:42:52 -0700 Subject: [PATCH 13/15] cleanup --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e2ed268b8a47..ac8c927f22c3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -406,7 +406,7 @@ def chat( # messages is List[...] 
list_of_messages = [messages] - prompts: List[List[PromptType]] = [] + prompts: List[PromptType] = [] for msgs in list_of_messages: tokenizer = self.get_tokenizer() From e9510fe45508634ac72308b48aaee27a9517904b Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 23 Sep 2024 23:46:18 -0700 Subject: [PATCH 14/15] typing --- vllm/entrypoints/llm.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e6287036fc62..cd10eda8c212 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -535,7 +535,7 @@ def chat( # messages is List[...] list_of_messages = [messages] - prompts: List[PromptType] = [] + prompts: List[Union[TokensPrompt, TextPrompt]] = [] for msgs in list_of_messages: tokenizer = self.get_tokenizer() @@ -562,7 +562,7 @@ def chat( tools=tools, ) - prompt: PromptType + prompt: Union[TokensPrompt, TextPrompt] if is_list_of(prompt_data, int): prompt = TokensPrompt(prompt_token_ids=prompt_data) else: @@ -573,15 +573,13 @@ def chat( prompts.append(prompt) - outputs = self.generate( + return self.generate( prompts, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, ) - return outputs - @overload # LEGACY: single (prompt + optional token ids) def encode( self, From 049a6572f971ec06b4a4a8f3aa1d55ad3715f21b Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 23 Sep 2024 23:49:59 -0700 Subject: [PATCH 15/15] turn on tqdm --- examples/offline_inference_chat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index 1d5451bab161..8814f4d7bef0 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -59,9 +59,11 @@ def print_outputs(outputs): }, ] conversations = [conversation for _ in range(10)] + +# We turn on tqdm progress bar to verify it's indeed running batch inference outputs = llm.chat(messages=conversations, sampling_params=sampling_params, - use_tqdm=False) + use_tqdm=True) print_outputs(outputs) # A chat template can be optionally supplied.
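
Usage sketch: the batched llm.chat call that this series ends up with, as exercised by test_multi_chat in tests/entrypoints/llm/test_generate.py and by the updated examples/offline_inference_chat.py. This is a minimal illustration only; the model name and prompts are taken from those files, and any chat-capable model served by vLLM is used the same way.

    from vllm import LLM, SamplingParams

    # Model and prompts mirror tests/entrypoints/llm/test_generate.py and are
    # illustrative only.
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    sampling_params = SamplingParams(temperature=0.5)

    conversation1 = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Explain the concept of entropy."},
    ]
    conversation2 = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Explain what among us is."},
    ]

    # A list of conversations is chat-templated per conversation and then
    # dispatched as a single batched generate() call; one RequestOutput is
    # returned per conversation, in input order.
    outputs = llm.chat(
        [conversation1, conversation2],
        sampling_params=sampling_params,
        use_tqdm=True,
    )
    assert len(outputs) == 2

    # A single conversation is still accepted and likewise returns a flat
    # List[RequestOutput] (of length 1), so existing callers keep working.
    single = llm.chat(conversation1, sampling_params=sampling_params)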