diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fd28bf39e2d5..081122c8cb34 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -531,6 +531,7 @@ def beam_search( prompts: list[Union[TokensPrompt, TextPrompt]], params: BeamSearchParams, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + use_tqdm: bool = False, ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -540,6 +541,7 @@ def beam_search( of token IDs. params: The beam search parameters. lora_request: LoRA request to use for generation, if any. + use_tqdm: Whether to use tqdm to display the progress bar. """ # TODO: how does beam search work together with length penalty, # frequency, penalty, and stopping criteria, etc.? @@ -602,7 +604,18 @@ def create_tokens_prompt_from_beam( **mm_kwargs, ), ) - for _ in range(max_tokens): + token_iter = range(max_tokens) + if use_tqdm: + token_iter = tqdm(token_iter, + desc="Beam search", + unit="token", + unit_scale=False) + logger.warning( + "The progress bar shows the upper bound on token steps and " + "may finish early due to stopping conditions. It does not " + "reflect instance-level progress.") + + for _ in token_iter: all_beams: list[BeamSearchSequence] = list( sum((instance.beams for instance in instances), [])) pos = [0] + list(