From e37b9db52c36f5e34d17a4e233c22f6f5b369023 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 5 Sep 2024 19:35:09 +0000 Subject: [PATCH 1/5] Multi-Step Chunked-Prefill Support --- csrc/prepare_inputs/advance_step.cu | 2 +- .../multi_step/test_correctness_async_llm.py | 9 ++ tests/multi_step/test_correctness_llm.py | 4 + vllm/attention/backends/flash_attn.py | 32 +++- vllm/attention/backends/flashinfer.py | 20 ++- vllm/config.py | 13 +- vllm/core/block/block_table.py | 13 +- vllm/core/block_manager_v1.py | 7 +- vllm/core/block_manager_v2.py | 5 +- vllm/core/embedding_model_block_manager.py | 4 +- vllm/core/interfaces.py | 4 +- vllm/core/scheduler.py | 127 +++++++++++---- vllm/engine/arg_utils.py | 10 +- vllm/engine/async_llm_engine.py | 9 +- vllm/engine/llm_engine.py | 130 +++++++++++++-- vllm/engine/output_processor/multi_step.py | 1 + vllm/model_executor/sampling_metadata.py | 49 ++++++ vllm/sequence.py | 46 +++++- vllm/worker/model_runner_base.py | 6 + vllm/worker/multi_step_model_runner.py | 148 +++++++++++++++--- vllm/worker/multi_step_worker.py | 5 +- 21 files changed, 540 insertions(+), 104 deletions(-) diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 1f3f4710735e..195eb27dee74 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel( slot_mapping_ptr[cur_query_id] = slot_num; } -inline void verify_tensor(std::string const& name, torch::Tensor& t, +inline void verify_tensor(std::string const& name, torch::Tensor const& t, int64_t const size_0, int64_t const size_1, c10::ScalarType const type) { bool size_0_cond = true; diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index a75a671e57f7..615549f2134a 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -37,6 +37,7 @@ @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("is_async", [True]) @pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) +@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) @pytest.mark.asyncio async def test_multi_step( example_prompts, @@ -49,6 +50,7 @@ async def test_multi_step( is_async: bool, num_logprobs: Optional[int], attention_backend: str, + enable_chunked_prefill: bool, monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling in an OpenAI-protocol @@ -74,6 +76,10 @@ async def test_multi_step( num_logprobs: corresponds to the `logprobs` argument to the OpenAI completions endpoint; `None` -> no logprobs """ + if enable_chunked_prefill and \ + (pp_size > 1 or attention_backend != "FLASH_ATTN"): + pytest.skip("Multi-step with Chunked-Prefill only supports" + "PP=1 and FLASH_ATTN backend") override_backend_env_variable(monkeypatch, attention_backend) @@ -93,6 +99,9 @@ async def test_multi_step( if eager_mode: ms_server_args.append("--enforce-eager") + if enable_chunked_prefill: + ms_server_args.append("--enable-chunked-prefill") + distributed_args = [ "--tensor-parallel-size", str(tp_size), diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index c5dc81cc2562..ff413e8e2da3 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -16,6 +16,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("tp_size", [1]) 
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @@ -28,6 +29,7 @@ def test_multi_step_llm( model: str, dtype: str, tp_size: int, + enable_chunked_prefill: bool, max_tokens: int, enforce_eager: int, num_scheduler_steps: int, @@ -51,6 +53,7 @@ def test_multi_step_llm( model: model under test (same for single- and multi-step engines) dtype: tensor datatype for engine to utilize tp_size: degree of tensor-parallelism + enable_chunked_prefill: chunked-prefill on/off max_tokens: the maximum number of tokens to generate enforce_eager num_scheduler_steps: for multi-step scheduling, GPU-side steps per @@ -73,6 +76,7 @@ def test_multi_step_llm( gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, use_v2_block_manager=True, + enable_chunked_prefill=enable_chunked_prefill, num_scheduler_steps=num_scheduler_steps, ) as vllm_model: vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 084e8113cd42..3a7a8cd62ba5 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -342,9 +342,13 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: ) return self._cached_decode_metadata - def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int): + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ @@ -355,6 +359,23 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert num_seqs > num_queries assert self.use_cuda_graph + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + assert self.num_prefills == 0 assert self.num_prefill_tokens == 0 assert self.num_decode_tokens == num_seqs @@ -366,7 +387,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", assert self.seq_lens_tensor.shape == (num_seqs, ) assert self.max_query_len == 1 assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) assert self.query_start_loc is not None assert self.query_start_loc.shape == (num_queries + 1, ) @@ -704,8 +724,10 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3a602fbfbbc0..fa2f70dde9f2 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -410,18 +410,22 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self - def advance_step( - self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - ): + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with flashinfer yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + assert num_seqs > 0 assert num_queries > 0 assert model_input.attn_metadata is not None diff --git a/vllm/config.py b/vllm/config.py index 108badf150c8..3139c5a08bfb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -983,9 +983,16 @@ def __init__(self, policy: str = "fcfs") -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: - # It is the values that have the best balance between ITL - # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + if num_scheduler_steps > 1: + # Multi-step Chunked-Prefill doesn't allow prompt-chunking + # for now. Have max_num_batched_tokens set to max_model_len + # so we don't reject sequences on account of a short + # max_num_batched_tokens. + max_num_batched_tokens = max(max_model_len, 2048) + else: + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. 
+ max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index c002dd1397f9..a9f4bd871dfd 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -55,9 +55,12 @@ def __init__( self._num_full_slots = self._get_num_token_ids() @staticmethod - def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: """Calculates the minimum number of blocks required to store a given - sequence of token IDs. + sequence of token IDs along with any look-ahead slots that may be + required (like in multi-step + chunked-prefill). This assumes worst-case scenario, where every block requires a new allocation (e.g. ignoring prefix caching). @@ -66,12 +69,14 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: token_ids (List[int]): The sequence of token IDs to be stored. block_size (int): The maximum number of tokens that can be stored in a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. Returns: int: The minimum number of blocks required to store the given - sequence of token IDs. + sequence of token IDs along with any required look-ahead slots. """ - return cdiv(len(token_ids), block_size) + return cdiv(len(token_ids) + num_lookahead_slots, block_size) def allocate(self, token_ids: List[int], diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 24ab9eb66194..a1f96707a6b5 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -281,10 +281,15 @@ def __init__( def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: return 0 if seq is None else seq.n_blocks - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerV1" + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) self_num_required_blocks = self._get_seq_num_required_blocks( diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 54818c7e3e9a..bb78b1e1c913 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -107,7 +107,9 @@ def __init__( self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. 
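Editorial note, not part of the patch: the get_num_required_blocks change above only folds the lookahead slots into the ceiling division, so a prompt that sits near a block boundary may need one extra block when scheduled for multi-step. A self-contained sketch of the arithmetic, with made-up numbers and a local stand-in for the cdiv helper the file already uses:

    def cdiv(a: int, b: int) -> int:
        # local stand-in for the cdiv helper imported in block_table.py
        return -(a // -b)

    block_size = 16
    prompt_len = 30
    num_lookahead_slots = 7   # e.g. num_scheduler_steps = 8

    assert cdiv(prompt_len, block_size) == 2                        # without lookahead
    assert cdiv(prompt_len + num_lookahead_slots, block_size) == 3  # with lookahead
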
@@ -117,6 +119,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, ) if seq_group.is_encoder_decoder(): diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index c47d7d8dfb07..476e043ecc52 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -21,7 +21,9 @@ def __init__( ) -> None: pass - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: # Always return OK for dummy purposes return AllocStatus.OK diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 96f8dd851b2f..634671158730 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -44,7 +44,9 @@ def get_block_space_manager_class(version: str): raise ValueError(f"Unknown version {version=}") @abstractmethod - def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 873decff37c1..44a1cbdab16e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -522,7 +522,7 @@ def _schedule_running( ret.swapped_out.clear() ret.num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill=False) + is_prefill=False, enable_chunking=enable_chunking) ret.decode_seq_groups_list.clear() ret.prefill_seq_groups_list.clear() @@ -561,7 +561,7 @@ def _schedule_running( # NOTE(woosuk): Preemption happens only when there is no available # slot to keep all the sequence groups in the RUNNING state. - while not self._can_append_slots(seq_group): + while not self._can_append_slots(seq_group, enable_chunking): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) num_running_seqs = seq_group.get_max_num_running_seqs() @@ -611,7 +611,7 @@ def _schedule_running( if not cont_loop: break else: - self._append_slots(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = \ @@ -684,7 +684,8 @@ def _schedule_swapped( # If the sequence group cannot be swapped in, stop. 
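# ---- Editorial illustration, not part of the patch ----
# A standalone sketch of what the reworked _get_num_lookahead_slots (later in
# this file's diff) returns once the enable_chunking flag is threaded through,
# assuming num_scheduler_steps = 8, i.e. scheduler_config.num_lookahead_slots = 7.
def _lookahead_sketch(is_prefill: bool, enable_chunking: bool,
                      is_multi_step: bool = True,
                      num_lookahead_slots: int = 7) -> int:
    if is_prefill:
        # Prefills get lookahead slots only under multi-step + chunked-prefill,
        # because they turn into decodes after the very first step.
        return (num_lookahead_slots + 1
                if is_multi_step and enable_chunking else 0)
    return num_lookahead_slots

assert _lookahead_sketch(is_prefill=False, enable_chunking=True) == 7
assert _lookahead_sketch(is_prefill=True, enable_chunking=True) == 8
assert _lookahead_sketch(is_prefill=True, enable_chunking=False) == 0
# ---- end editorial illustration ----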
is_prefill = seq_group.is_prefill() alloc_status = self.block_manager.can_swap_in( - seq_group, self._get_num_lookahead_slots(is_prefill)) + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) if alloc_status == AllocStatus.LATER: break elif alloc_status == AllocStatus.NEVER: @@ -727,7 +728,7 @@ def _schedule_swapped( curr_loras.add(lora_int_id) swapped_queue.popleft() self._swap_in(seq_group, blocks_to_swap_in) - self._append_slots(seq_group, blocks_to_copy) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( @@ -747,12 +748,13 @@ def _schedule_swapped( blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False), + is_prefill=False, enable_chunking=enable_chunking), infeasible_seq_groups=infeasible_seq_groups, ) def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if self.scheduler_config.chunked_prefill_enabled: + if self.scheduler_config.chunked_prefill_enabled and \ + not self.scheduler_config.is_multi_step: prompt_limit = self.scheduler_config.max_model_len else: prompt_limit = min(self.scheduler_config.max_model_len, @@ -899,13 +901,20 @@ def _schedule_prefills( waiting_queue.popleft() continue + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: logger.warning( - "Input prompt (%d tokens) is too long" + "Input prompt (%d tokens) + lookahead slots " + "({num_lookahead_slots}) is too long" " and exceeds the capacity of block_manager", num_new_tokens) for seq in waiting_seqs: @@ -939,9 +948,24 @@ def _schedule_prefills( curr_loras.add(lora_int_id) waiting_queue.popleft() self._allocate_and_set_running(seq_group) - seq_group.init_multi_step( - num_scheduler_steps=self._get_num_lookahead_slots( - is_prefill=True) + 1) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. + num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -956,7 +980,8 @@ def _schedule_prefills( return SchedulerPrefillOutputs( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking)) def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. 
@@ -1153,7 +1178,8 @@ def _schedule(self) -> SchedulerOutputs: else: return self._schedule_default() - def _can_append_slots(self, seq_group: SequenceGroup) -> bool: + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ @@ -1164,12 +1190,16 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: self.artificial_preempt_cnt -= 1 return False - # Appending slots only occurs in decoding. - is_prefill = False + is_prefill = seq_group.is_prefill() + + # Appending prefill slots only happens chunked prefill is enabled. + assert self.scheduler_config.chunked_prefill_enabled or \ + not is_prefill return self.block_manager.can_append_slots( seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill, enable_chunking), ) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: @@ -1186,7 +1216,7 @@ def schedule( # such as self.running, self.swapped, and self.waiting. scheduler_start_time = time.perf_counter() - scheduler_outputs = self._schedule() + scheduler_outputs: SchedulerOutputs = self._schedule() now = time.time() if not self.cache_config.enable_prefix_caching: @@ -1383,11 +1413,10 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING - def _append_slots( - self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], - ) -> None: + def _append_slots(self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False) -> None: """Appends new slots to the sequences in the given sequence group. Args: @@ -1398,11 +1427,25 @@ def _append_slots( int is the destination block index. This list is updated with the new source and destination block indices for the appended slots. + enable_chunking (bool): True if chunked prefill is enabled. """ - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) - seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): cows = self.block_manager.append_slots(seq, num_lookahead_slots) if len(cows) > 0: blocks_to_copy.extend(cows) @@ -1513,16 +1556,32 @@ def _passed_delay(self, now: float) -> bool: passed_delay = True return passed_delay - def _get_num_lookahead_slots(self, is_prefill: bool) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. 
Speculative decoding does not yet support prefill, so we do not perform lookahead allocation for prefill. + + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. """ if is_prefill: - return 0 + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 return self.scheduler_config.num_lookahead_slots @@ -1565,6 +1624,16 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if remaining_token_budget < num_new_tokens: num_new_tokens = (remaining_token_budget // block_size) * block_size + elif self.scheduler_config.is_multi_step: + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > remaining_token_budget \ + else num_new_tokens else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0d4559e37742..0efb0cbbf8be 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -980,9 +980,13 @@ def create_engine_config(self) -> EngineConfig: if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill: - raise ValueError("Chunked prefill is not supported with " - "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill and self.enable_prefix_caching: + raise ValueError("Multi-Step is not supported with " + "both Chunked-Prefill and Prefix-Caching " + "enabled together.") + if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: + raise ValueError("Multi-Step Chunked-Prefill is not supported " + "for pipeline-parallel-size > 1") # make sure num_lookahead_slots is set the higher value depending on # if we are using speculative decoding or multi-step diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 54c5af2fe366..3361fdefc960 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -363,11 +363,18 @@ async def step_async( self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. 
+ is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + ctx.append_output(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=allow_async_output_proc, - is_last_step=True) + is_last_step=True, + is_first_step_output=is_first_step_output) if outputs and allow_async_output_proc: assert len( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 487255cb6b59..4d453245b9d0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -90,6 +90,12 @@ class OutputData(NamedTuple): scheduler_outputs: SchedulerOutputs is_async: bool is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. + is_first_step_output: Optional[bool] skip: List[int] @@ -108,13 +114,15 @@ def __init__(self, multi_step_stream_outputs: bool = False): def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool): + is_last_step: bool, + is_first_step_output: Optional[bool]): self.output_queue.append( OutputData(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=is_async, is_last_step=is_last_step, + is_first_step_output=is_first_step_output, skip=[])) @@ -237,9 +245,10 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, multi_step_stream_outputs=%s, " - "enable_prefix_caching=%s, use_async_output_proc=%s, " - "use_cached_outputs=%s, mm_processor_kwargs=%s)", + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -270,6 +279,7 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, @@ -957,8 +967,66 @@ def _process_model_outputs(self, ctx: The virtual engine context to work on request_id: If provided, then only this request is going to be processed - """ + + def update_prefill_num_computed_tokens( + seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, num_outputs: int, + is_first_step_output: Optional[bool]) -> None: + """ + seq_group: SequenceGroup - A prefill seq_group + seq_group_meta: SequenceGroupMetadata - Metadata of the given + prefill seq_group + num_outputs: int - number of output tokens being processed for the + given seq_group + is_first_step_output: Optional[bool] - + If multi-step is enabled and num_outputs is 1, this value + indicates if this outputs belongs to the first step in the + multi-step. + If multi-step is enabled and num_outputs > 1, this value + must be None, as num_outputs > 1 indicates that outputs from + all the steps in multi-step are submitted in a single burst. + When multi-step is disabled, this value is always True. 
+ + When multi-step and chunked-prefill are enabled together, the + prefill sequence scheduled for multi-step execution turn into + decodes in the first step itself. This function accounts + for that conversion. + """ + + assert seq_group_meta.is_prompt + + token_chunk_size = seq_group_meta.token_chunk_size + + if num_outputs == 1: + assert is_first_step_output is not None + + if seq_group_meta.state.num_steps == 1: + assert is_first_step_output is True + seq_group.update_num_computed_tokens(token_chunk_size) + return + + # multi-step prefill is only supported when multi-step is + # enabled with chunked prefill + assert self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled + if is_first_step_output is True: + # This sequence is a prompt during the first step only. + seq_group.update_num_computed_tokens(token_chunk_size) + return + + assert is_first_step_output is None + + # multi-step prefill is only supported when multi-step is + # enabled with chunked prefill. Outputs from all the steps are + # submitted in a single burst. + assert self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled + assert num_outputs == seq_group_meta.state.num_steps, \ + f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa + # This sequence is a prompt during the first step only. + seq_group.update_num_computed_tokens(token_chunk_size) + now = time.time() if len(ctx.output_queue) == 0: @@ -969,20 +1037,27 @@ def _process_model_outputs(self, # When we process only one request, no pop is required # (since later we will process all of the rest) (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue[0] + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] else: (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, skip) = ctx.output_queue.popleft() + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( scheduler_outputs.scheduled_seq_groups) - # Organize outputs by [step][sequence group] instead of - # [sequence group][step]. - if len(outputs) > 1: + has_multiple_outputs: bool = len(outputs) > 1 + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. outputs_by_sequence_group = create_output_by_sequence_group( outputs, num_seq_groups=len(seq_group_metadata_list)) + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. + is_first_step_output = None else: outputs_by_sequence_group = outputs @@ -1018,14 +1093,17 @@ def _process_model_outputs(self, finished_before.append(i) continue - if len(outputs) > 1: + if has_multiple_outputs: output = outputs_by_sequence_group[i] else: output = [outputs_by_sequence_group[0][i]] - if not is_async: - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) + if not is_async and seq_group_meta.is_prompt: + # Updates for all decodes happen when we actually append the + # token ids to the seq in process_outputs. 
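# ---- Editorial illustration, not part of the patch ----
# How num_computed_tokens is expected to advance for a 5-token prompt scheduled
# with num_steps = 4 under multi-step + chunked-prefill, following the helper
# above. The numbers are made up.
prompt_len, num_steps = 5, 4
num_computed_tokens = 0

# Step 1: the sequence is still a prompt; the whole chunk is accounted for via
# update_prefill_num_computed_tokens (token_chunk_size == prompt_len here).
num_computed_tokens += prompt_len                 # 5

# Steps 2..4: the sequence has turned into a decode; each appended token is
# counted one at a time when outputs are processed.
for _ in range(num_steps - 1):
    num_computed_tokens += 1                      # 6, 7, 8

assert num_computed_tokens == prompt_len + (num_steps - 1)
# ---- end editorial illustration ----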
+ update_prefill_num_computed_tokens(seq_group, seq_group_meta, + len(output), + is_first_step_output) if outputs: for o in outputs: @@ -1159,8 +1237,18 @@ def _advance_to_next_step( if seq_group.is_finished(): continue - seq_group.update_num_computed_tokens( - seq_group_metadata.token_chunk_size) + if seq_group_metadata.is_prompt: + if self.scheduler_config.is_multi_step and \ + self.scheduler_config.chunked_prefill_enabled: + # Prompts are scheduled in multi-step only when + # chunking is enabled. These prompts turn into + # decodes after the very first step. Therefore, + # we skip the update to the num_computed_tokens + # here. + pass + else: + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( @@ -1172,6 +1260,7 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) + seq_group.update_num_computed_tokens(1) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -1324,12 +1413,19 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + # Add results to the output_queue ctx.append_output(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, scheduler_outputs=scheduler_outputs, is_async=allow_async_output_proc, - is_last_step=True) + is_last_step=True, + is_first_step_output=is_first_step_output) if outputs and allow_async_output_proc: assert len(outputs) == 1, ( diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 31c2bbc8e712..cd5cfe5485f2 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -170,6 +170,7 @@ def _process_seq_outputs(self, seq: Sequence, token_id=output_token_id, logprobs=output_logprob, ) + seq.data.update_num_computed_tokens(1) self._process_decode_and_stop(seq, sampling_params) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 97d36d31f2b1..0ecc4f7157c7 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -134,6 +134,9 @@ def __init__( num_prompts: int, skip_sampler_cpu_output: bool = False, reuse_sampling_tensors: bool = False, + # Used when multi-step is enabled with chunked-prefill. Refer to + # the comment in prepare_multistep_tensors. + selected_token_indices_multistep: Optional[torch.Tensor] = None ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices @@ -141,6 +144,52 @@ def __init__( self.num_prompts = num_prompts self.skip_sampler_cpu_output = skip_sampler_cpu_output self.reuse_sampling_tensors = reuse_sampling_tensors + self.selected_token_indices_multistep = selected_token_indices_multistep + + def prepare_multistep_tensors(self, num_queries: int, device: str, + pin_memory: bool): + """ + Invoked when Multi-Step is enabled with Chunked-Prefill. 
+ When Multi-Step is enabled with Chunked-Prefill, the prompts and + decodes are scheduled together. + self.selected_token_indices is constructed for the first-step in + Multi-Step. However, the scheduled prompts, are fully processed + in the first-step and are processed as decodes in the rest of the steps. + This function prepares a "selected_token_indices" to be used + in the rest of the steps. + + Example: + Let 2 prompts and 2 decodes be scheduled together. Let the + num-tokens to process for the 2 prompts be 5 and 8 resply. + + In that case, self.sampled_token_indices will be, + [4, 12, 13, 14] as it is constructed for the first-step in + multi-step. + However, the prompts turns to decodes after the first-step + and the num-tokens for the previously-prompt sequences will + be 1 and 1 as they are decodes now. The self.sampled_token_indices + must be updated to [0,1,2,3]. + prepare_multistep_tensors prepares the "selected_token_indices" + to be used in steps 2-N. + """ + selected_token_indices_multistep = list(range(num_queries)) + self.selected_token_indices_multistep = \ + async_tensor_h2d(selected_token_indices_multistep, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + + def advance_step(self): + """ + Invoked when Multi-Step and Chunked-Prefill are enabled together. + The prefills that may have been scheduled, are fully processed in + the very first step and have turned into decodes. + Updated selected_token_indices to reflect that. Please refer to + the prepare_multistep_tensors docstring for an example. + """ + if self.selected_token_indices_multistep is not None: + # Swap to account for Single Step Prompts becoming Decodes + self.selected_token_indices = self.selected_token_indices_multistep @staticmethod def prepare( diff --git a/vllm/sequence.py b/vllm/sequence.py index 49a198df045b..781bcedde2b5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -743,10 +743,35 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ if self.prompt_adapter_request else 0 - def init_multi_step(self, num_scheduler_steps: int) -> None: - self.state.num_steps = num_scheduler_steps + def init_multi_step(self, num_steps: int) -> None: + self.state.num_steps = num_steps self.state.current_step = 0 + def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int, + num_scheduler_steps: int, + is_multi_step: bool, + enable_chunking: bool) -> None: + + if not is_multi_step: + self.init_multi_step(num_steps=num_scheduler_steps) + return + + # Multi-Step case + is_prefill = self.is_prefill() + + # The asserts below reflect the expectations of the current system. + if is_prefill and enable_chunking: + assert num_lookahead_slots == num_scheduler_steps + self.init_multi_step(num_steps=num_lookahead_slots) + else: + is_decode: bool = not is_prefill + # If it is a prefill, num_lookahead_slots must be 0 + assert num_lookahead_slots == 0 or is_decode + # If it is a decode, num_lookahead_slots + 1 must match + # the scheduler steps. + assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill + self.init_multi_step(num_steps=num_lookahead_slots + 1) + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. 
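Editorial note, not part of the patch: with num_scheduler_steps = 8, init_multi_step_from_lookahead_slots above lands on num_steps = 8 in both cases — a decode gets its 7 lookahead slots plus the step it is currently running, while a chunked prefill is handed 8 lookahead slots directly because its first step is spent on the prompt chunk. A standalone sketch of that arithmetic:

    def _num_steps_sketch(is_prefill: bool, enable_chunking: bool,
                          num_scheduler_steps: int = 8) -> int:
        # lookahead slots as handed over by the scheduler in the multi-step case
        num_lookahead_slots = (num_scheduler_steps
                               if is_prefill and enable_chunking
                               else num_scheduler_steps - 1)
        if is_prefill and enable_chunking:
            return num_lookahead_slots      # 8: prompt chunk step + 7 decode steps
        return num_lookahead_slots + 1      # 8: current decode step + 7 lookahead

    assert _num_steps_sketch(is_prefill=True, enable_chunking=True) == 8
    assert _num_steps_sketch(is_prefill=False, enable_chunking=True) == 8
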
@@ -1010,6 +1035,20 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ if self.prompt_adapter_request else 0 + # Multi-Step Chunked-Prefill property + @property + def is_single_step_prompt(self) -> bool: + # do_sample is true, only when the token_chunk_size matches the + # num_uncomputed_tokens of the sequence. This indicates that + # the prompt will finish processing in a single `execute_model` + # step. + return self.is_prompt and self.do_sample + + def get_first_seq_id(self) -> int: + # This is an efficient way of fetching the seq_id when + # we know this SequenceGroup has only one sequence. + return next(iter(self.seq_data)) + def apply_delta(self, sequence_group_metadata_delta: SequenceGroupMetadataDelta): for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): @@ -1022,7 +1061,8 @@ def apply_delta(self, def finish_step(self) -> None: assert self.state is not None - assert self.state.current_step < self.state.num_steps + assert self.state.current_step < self.state.num_steps, \ + f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa self.state.current_step += 1 diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 86883cf15244..1bb6a848390d 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -64,6 +64,8 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore from vllm.model_executor import SamplingMetadata selected_token_indices = tensor_dict.pop("selected_token_indices", None) + selected_token_indices_multistep = tensor_dict.pop( + "selected_token_indices_multistep", None) # An empty SamplingMetadata to signal that the worker should skip # sampling. if selected_token_indices is not None: @@ -72,6 +74,7 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore selected_token_indices=selected_token_indices, categorized_sample_indices=None, num_prompts=0, + selected_token_indices_multistep=selected_token_indices_multistep, ) return tensor_dict @@ -86,6 +89,9 @@ def _add_sampling_metadata_broadcastable_dict( if sampling_metadata is not None: tensor_dict["selected_token_indices"] = ( sampling_metadata.selected_token_indices) + if sampling_metadata.selected_token_indices_multistep is not None: + tensor_dict["selected_token_indices_multistep"] = ( + sampling_metadata.selected_token_indices_multistep) def _init_frozen_model_input_from_tensor_dict( diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index c7295f872f70..01f0fc2d569b 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -30,6 +30,14 @@ logger = init_logger(__name__) MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"] + +def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ + -> List[str]: + if chunked_prefill_enabled: + return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS + else: + return MULTI_STEP_ATTENTION_BACKENDS def seq_output_builder(): @@ -144,11 +152,13 @@ class StatefulModelInput(BroadcastableModelInput): is_multi_step: bool = True is_last_step: bool = False is_first_multi_step: bool = False + base_output_proc_callback: Optional[Callable] = None # ping-pong data structures for multi-step to wait on the previous step step_cuda_events: List[torch.cuda.Event] = field( default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) 
num_seqs: int = -1 num_queries: int = -1 + num_single_step_prefills: int = 0 def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: assert self.frozen_model_input is not None @@ -161,6 +171,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: 'is_first_multi_step': self.is_first_multi_step, 'num_seqs': self.num_seqs, 'num_queries': self.num_queries, + 'num_single_step_prefills': self.num_single_step_prefills, } tensor_dict.update(new_tensor_dict) return tensor_dict @@ -209,6 +220,50 @@ def add_sampler_output(self, sampled_token_ids=sampled_token_ids, pythonized=False)) + def maybe_advance_frozen_model_input(self): + """ + Advancing the datastructures of StatefulModelInput::frozen_model_input + is only required when prefills are scheduled with decodes to run in + multi-step. This advancement/correction is required to account for + the conversion of Prefills to Decodes after the first multi-step. + """ + if self.current_step != 1 or self.num_single_step_prefills == 0: + return + + assert self.frozen_model_input is not None + fmi = self.frozen_model_input + + # Truncate input_tokens + assert fmi.input_tokens is not None + assert fmi.input_tokens.shape[0] >= self.num_seqs + fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] + + # Update frozen_model_input::input_positons. + assert fmi.input_positions is not None + assert fmi.input_positions.shape[0] >= self.num_seqs + fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. + num_seqs] + + # Assert unsupported + assert fmi.lora_mapping is None + assert fmi.lora_requests is not None + assert len(fmi.lora_requests) == 0 + assert fmi.attn_metadata is not None + assert fmi.prompt_adapter_mapping is None + assert fmi.prompt_adapter_requests is not None + assert len(fmi.prompt_adapter_requests) == 0 + assert fmi.multi_modal_kwargs is not None + assert len(fmi.multi_modal_kwargs) == 0 + + self.frozen_model_input = dataclasses.replace( + self.frozen_model_input, + input_tokens=fmi_new_input_tokens, + input_positions=fmi_new_input_positions) + + if get_pp_group().is_last_rank: + assert self.frozen_model_input.sampling_metadata is not None + self.frozen_model_input.sampling_metadata.advance_step() + # MutableModelInputForGPUWithMultiStepMetadata is not subclass of # ModelInputForGPU but it wraps the actual input dataclass and adds multi-step @@ -220,6 +275,19 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): super().__init__(*args, **kwargs) + # Check attention backend support. + supported_attention_backends: List[str] = \ + _get_supported_attention_backends( + self.scheduler_config.chunked_prefill_enabled) + if self.attn_backend.get_name() not in supported_attention_backends: + ms_config_str: str = "Multi-Step + Chunked-Prefill" \ + if self.scheduler_config.chunked_prefill_enabled \ + else "Multi-Step" + raise ValueError( + f"{ms_config_str} not supported for attention backend: " + f"{self.attn_backend.get_name()}. 
Set VLLM_ATTENTION_BACKEND " + f"to a value from {supported_attention_backends}.") + # uses the base model runner to execute the model and wraps it with # multi-step logic self._base_model_runner: GPUModelRunnerBase = base_model_runner @@ -248,14 +316,32 @@ def prepare_model_input( virtual_engine: int = 0, finished_requests_ids: Optional[List[str]] = None ) -> StatefulModelInput: - frozen_model_input = self._base_model_runner.prepare_model_input( - seq_group_metadata_list, virtual_engine, finished_requests_ids) + frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ + self._base_model_runner.prepare_model_input( + seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + + assert frozen_model_input.query_lens is not None + assert frozen_model_input.seq_lens is not None + assert frozen_model_input.attn_metadata is not None + num_queries = len(frozen_model_input.query_lens) + num_seqs = len(frozen_model_input.seq_lens) + num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills + + if get_pp_group().is_last_rank and num_single_step_prefills > 0: + assert frozen_model_input.sampling_metadata is not None + frozen_model_input.sampling_metadata.prepare_multistep_tensors( + num_queries=num_queries, + device=self.device, + pin_memory=self.pin_memory) model_input = StatefulModelInput( frozen_model_input=frozen_model_input, - num_seqs=len(frozen_model_input.seq_lens), - num_queries=len(frozen_model_input.query_lens), - ) + num_seqs=num_seqs, + num_queries=num_queries, + num_single_step_prefills=num_single_step_prefills) + return model_input def _async_process_outputs(self, model_input: StatefulModelInput, @@ -265,7 +351,7 @@ def _async_process_outputs(self, model_input: StatefulModelInput, output_proc_callback() cont = True - for model_output in model_input.cached_outputs: + for step_num, model_output in enumerate(model_input.cached_outputs): if not model_output.pythonized: model_output.maybe_pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids) @@ -276,7 +362,8 @@ def _async_process_outputs(self, model_input: StatefulModelInput, seq_group_metadata_list=ctx.seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, - is_last_step=False) + is_last_step=False, + is_first_step_output=step_num == 0) output_proc_callback() else: @@ -292,9 +379,8 @@ def _final_process_outputs(self, model_input: StatefulModelInput, has_async_callback = output_proc_callback is not None outputs = [] - for output_id in range(len(model_input.cached_outputs)): - output = model_input.cached_outputs[output_id] - is_last_step = output_id == len(model_input.cached_outputs) - 1 + for step_num, output in enumerate(model_input.cached_outputs): + is_last_step = step_num == len(model_input.cached_outputs) - 1 # For non-async case: # -- We simply add the outputs @@ -323,7 +409,8 @@ def _final_process_outputs(self, model_input: StatefulModelInput, seq_group_metadata_list, scheduler_outputs=ctx.scheduler_outputs, is_async=False, - is_last_step=False) + is_last_step=False, + is_first_step_output=step_num == 0) else: outputs.append(output.sampler_output) else: @@ -389,18 +476,27 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) - output_proc_callback = None + # frozen_model_input may have been updated + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + if model_input.base_output_proc_callback is None: + assert frozen_model_input is not None + 
model_input.base_output_proc_callback = \ + frozen_model_input.async_callback + if frozen_model_input.async_callback is not None: - output_proc_callback = frozen_model_input.async_callback - assert output_proc_callback is not None + assert model_input.base_output_proc_callback is not None async_callback = functools.partial( self._async_process_outputs, model_input=model_input, - output_proc_callback=output_proc_callback) + output_proc_callback=model_input.base_output_proc_callback) - frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input = dataclasses.replace( # type: ignore model_input.frozen_model_input, async_callback=async_callback) + # Update the local instance + frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None # Execute the model @@ -455,8 +551,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = self._final_process_outputs(model_input, - output_proc_callback) + outputs = self._final_process_outputs( + model_input, model_input.base_output_proc_callback) self.pythonization_cache.reset() return outputs @@ -484,11 +580,13 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: - if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS: - raise ValueError( - f"Multi-step not supported for attention backend: " - f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND " - f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.") + + model_input.maybe_advance_frozen_model_input() + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.input_tokens is not None + assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs + assert frozen_model_input.attn_metadata is not None sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids num_seqs = model_input.num_seqs @@ -498,13 +596,15 @@ def _advance_step(self, model_input: StatefulModelInput, attn_metadata = frozen_model_input.attn_metadata assert attn_metadata is not None + turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ + model_input.num_single_step_prefills != 0 attn_metadata.advance_step( frozen_model_input, sampled_token_ids, self.block_size, num_seqs, num_queries, - ) + turn_prefills_into_decodes=turn_prefills_into_decodes) return model_input diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 562285f828cc..bf66f32d7d24 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -76,8 +76,9 @@ def _get_driver_input_and_broadcast( frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None assert frozen_model_input.attn_metadata is not None - # clear the cached decode metadata so that it can be recomputed on - # the workers + # clear the cached metadata so that it can be recomputed on + # the workers. 
+ frozen_model_input.attn_metadata._cached_prefill_metadata = None frozen_model_input.attn_metadata._cached_decode_metadata = None model_input.is_first_multi_step = is_first_multi_step From 89e790c5945f3da6eba9d3ad7f20305f8c49f237 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 27 Sep 2024 15:29:41 +0000 Subject: [PATCH 2/5] review comments --- vllm/core/scheduler.py | 7 +++---- vllm/engine/llm_engine.py | 10 +++++----- vllm/model_executor/sampling_metadata.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 44a1cbdab16e..bda2310b3708 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -913,10 +913,9 @@ def _schedule_prefills( break elif can_allocate == AllocStatus.NEVER: logger.warning( - "Input prompt (%d tokens) + lookahead slots " - "({num_lookahead_slots}) is too long" - " and exceeds the capacity of block_manager", - num_new_tokens) + "Input prompt (%d tokens) + lookahead slots (%d) is " + "too long and exceeds the capacity of block_manager", + num_new_tokens, num_lookahead_slots) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4d453245b9d0..19f88ac3e7c5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -974,6 +974,11 @@ def update_prefill_num_computed_tokens( seq_group_meta: SequenceGroupMetadata, num_outputs: int, is_first_step_output: Optional[bool]) -> None: """ + When multi-step and chunked-prefill are enabled together, the + prefill sequence scheduled for multi-step execution turn into + decodes in the first step itself. This function accounts + for that conversion. + seq_group: SequenceGroup - A prefill seq_group seq_group_meta: SequenceGroupMetadata - Metadata of the given prefill seq_group @@ -987,11 +992,6 @@ def update_prefill_num_computed_tokens( must be None, as num_outputs > 1 indicates that outputs from all the steps in multi-step are submitted in a single burst. When multi-step is disabled, this value is always True. - - When multi-step and chunked-prefill are enabled together, the - prefill sequence scheduled for multi-step execution turn into - decodes in the first step itself. This function accounts - for that conversion. """ assert seq_group_meta.is_prompt diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 0ecc4f7157c7..bfcf815d5139 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -160,7 +160,7 @@ def prepare_multistep_tensors(self, num_queries: int, device: str, Example: Let 2 prompts and 2 decodes be scheduled together. Let the - num-tokens to process for the 2 prompts be 5 and 8 resply. + num-tokens to process for the 2 prompts be 5 and 8 respectively. 
In that case, self.sampled_token_indices will be, [4, 12, 13, 14] as it is constructed for the first-step in From d874ce6ef0957db92a0ad40ab2a601ecc04f0c9f Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 27 Sep 2024 16:37:40 +0000 Subject: [PATCH 3/5] Update selected_token_indices directly --- benchmarks/benchmark_throughput.py | 8 +++- vllm/model_executor/sampling_metadata.py | 49 ------------------------ vllm/worker/model_runner_base.py | 6 --- vllm/worker/multi_step_model_runner.py | 49 +++++++++++++++++------- 4 files changed, 41 insertions(+), 71 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 68b401d5bbbb..d5aa8a58996f 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -4,6 +4,7 @@ import random import time from typing import List, Optional, Tuple +import pickle as pkl import torch import uvloop @@ -126,7 +127,7 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=0.0 if use_beam_search else 0.0, top_p=1.0, use_beam_search=use_beam_search, ignore_eos=True, @@ -135,8 +136,11 @@ def run_vllm( if not use_new_beam_search_impl: start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() + + with open("llm_engine_test.pkl", "wb+") as f: + pkl.dump(outputs, f) else: assert use_beam_search prompts = [prompt for prompt, _, _ in requests] diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index bfcf815d5139..97d36d31f2b1 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -134,9 +134,6 @@ def __init__( num_prompts: int, skip_sampler_cpu_output: bool = False, reuse_sampling_tensors: bool = False, - # Used when multi-step is enabled with chunked-prefill. Refer to - # the comment in prepare_multistep_tensors. - selected_token_indices_multistep: Optional[torch.Tensor] = None ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices @@ -144,52 +141,6 @@ def __init__( self.num_prompts = num_prompts self.skip_sampler_cpu_output = skip_sampler_cpu_output self.reuse_sampling_tensors = reuse_sampling_tensors - self.selected_token_indices_multistep = selected_token_indices_multistep - - def prepare_multistep_tensors(self, num_queries: int, device: str, - pin_memory: bool): - """ - Invoked when Multi-Step is enabled with Chunked-Prefill. - When Multi-Step is enabled with Chunked-Prefill, the prompts and - decodes are scheduled together. - self.selected_token_indices is constructed for the first-step in - Multi-Step. However, the scheduled prompts, are fully processed - in the first-step and are processed as decodes in the rest of the steps. - This function prepares a "selected_token_indices" to be used - in the rest of the steps. - - Example: - Let 2 prompts and 2 decodes be scheduled together. Let the - num-tokens to process for the 2 prompts be 5 and 8 respectively. - - In that case, self.sampled_token_indices will be, - [4, 12, 13, 14] as it is constructed for the first-step in - multi-step. - However, the prompts turns to decodes after the first-step - and the num-tokens for the previously-prompt sequences will - be 1 and 1 as they are decodes now. The self.sampled_token_indices - must be updated to [0,1,2,3]. 
- prepare_multistep_tensors prepares the "selected_token_indices" - to be used in steps 2-N. - """ - selected_token_indices_multistep = list(range(num_queries)) - self.selected_token_indices_multistep = \ - async_tensor_h2d(selected_token_indices_multistep, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) - - def advance_step(self): - """ - Invoked when Multi-Step and Chunked-Prefill are enabled together. - The prefills that may have been scheduled, are fully processed in - the very first step and have turned into decodes. - Updated selected_token_indices to reflect that. Please refer to - the prepare_multistep_tensors docstring for an example. - """ - if self.selected_token_indices_multistep is not None: - # Swap to account for Single Step Prompts becoming Decodes - self.selected_token_indices = self.selected_token_indices_multistep @staticmethod def prepare( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 1bb6a848390d..86883cf15244 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -64,8 +64,6 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore from vllm.model_executor import SamplingMetadata selected_token_indices = tensor_dict.pop("selected_token_indices", None) - selected_token_indices_multistep = tensor_dict.pop( - "selected_token_indices_multistep", None) # An empty SamplingMetadata to signal that the worker should skip # sampling. if selected_token_indices is not None: @@ -74,7 +72,6 @@ def _init_sampling_metadata_from_tensor_dict( # type: ignore selected_token_indices=selected_token_indices, categorized_sample_indices=None, num_prompts=0, - selected_token_indices_multistep=selected_token_indices_multistep, ) return tensor_dict @@ -89,9 +86,6 @@ def _add_sampling_metadata_broadcastable_dict( if sampling_metadata is not None: tensor_dict["selected_token_indices"] = ( sampling_metadata.selected_token_indices) - if sampling_metadata.selected_token_indices_multistep is not None: - tensor_dict["selected_token_indices_multistep"] = ( - sampling_metadata.selected_token_indices_multistep) def _init_frozen_model_input_from_tensor_dict( diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 01f0fc2d569b..8f8f5f468f8e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -14,7 +14,7 @@ get_pythonized_sample_results) from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import PyObjectCache +from vllm.utils import PyObjectCache, async_tensor_h2d from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -220,7 +220,38 @@ def add_sampler_output(self, sampled_token_ids=sampled_token_ids, pythonized=False)) - def maybe_advance_frozen_model_input(self): + def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): + """ + sampling_metadata.selected_token_indices is constructed for the + first-step in Multi-Step. However, when chunked-prefill is enabled with + multi-step, the scheduled prompts are fully processed in the + first-step and are processed as decodes in the rest of the steps. + This function updates the sampling_metadata.selected_token_indices + to account for this conversion. + + Example: + Let 2 prompts and 2 decodes be scheduled together. 
Let the + num-tokens to process for the 2 prompts be 5 and 8 respectively. + + In that case, sampling_metadata.sampled_token_indices will be, + [4, 12, 13, 14] as it is constructed for the first-step in + multi-step. + However, the prompts turns to decodes after the first-step + and the num-tokens for the previously-prompt sequences will + be 1 and 1 as they are decodes now. The self.sampled_token_indices + must be updated to [0,1,2,3]. + """ + assert self.current_step == 1 and self.num_single_step_prefills > 0 + if not get_pp_group().is_last_rank: + return + + self.frozen_model_input.sampling_metadata.selected_token_indices = \ + async_tensor_h2d(list(range(self.num_queries)), + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + + def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): """ Advancing the datastructures of StatefulModelInput::frozen_model_input is only required when prefills are scheduled with decodes to run in @@ -260,10 +291,7 @@ def maybe_advance_frozen_model_input(self): input_tokens=fmi_new_input_tokens, input_positions=fmi_new_input_positions) - if get_pp_group().is_last_rank: - assert self.frozen_model_input.sampling_metadata is not None - self.frozen_model_input.sampling_metadata.advance_step() - + self.maybe_advance_sampling_metadata(device, pin_memory) # MutableModelInputForGPUWithMultiStepMetadata is not subclass of # ModelInputForGPU but it wraps the actual input dataclass and adds multi-step @@ -329,13 +357,6 @@ def prepare_model_input( num_seqs = len(frozen_model_input.seq_lens) num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills - if get_pp_group().is_last_rank and num_single_step_prefills > 0: - assert frozen_model_input.sampling_metadata is not None - frozen_model_input.sampling_metadata.prepare_multistep_tensors( - num_queries=num_queries, - device=self.device, - pin_memory=self.pin_memory) - model_input = StatefulModelInput( frozen_model_input=frozen_model_input, num_seqs=num_seqs, @@ -581,7 +602,7 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: - model_input.maybe_advance_frozen_model_input() + model_input.maybe_advance_frozen_model_input(self.device, self.pin_memory) frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None assert frozen_model_input.input_tokens is not None From b4650b61b49a9a2e49eb80499f85b69350149806 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 27 Sep 2024 17:02:23 +0000 Subject: [PATCH 4/5] make can_append_slots assert stronger --- vllm/core/scheduler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index bda2310b3708..66bbb86523fa 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1190,16 +1190,17 @@ def _can_append_slots(self, seq_group: SequenceGroup, return False is_prefill = seq_group.is_prefill() + num_lookahead_slots = self._get_num_lookahead_slots(is_prefill, + enable_chunking) - # Appending prefill slots only happens chunked prefill is enabled. - assert self.scheduler_config.chunked_prefill_enabled or \ - not is_prefill + if is_prefill and num_lookahead_slots > 0: + # Appending prefill slots only happens multi-step and + # chunked-prefill are enabled together. 
+ assert self.scheduler_config.is_multi_step and enable_chunking return self.block_manager.can_append_slots( seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill, enable_chunking), - ) + num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: no_beam_search = seq_group.sampling_params is None or ( From 7e8f66bb679565f6ee125c11679906d8ec10a07b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 27 Sep 2024 17:04:25 +0000 Subject: [PATCH 5/5] format --- benchmarks/benchmark_throughput.py | 8 ++------ vllm/core/scheduler.py | 7 +++---- vllm/worker/multi_step_model_runner.py | 6 +++++- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d5aa8a58996f..68b401d5bbbb 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -4,7 +4,6 @@ import random import time from typing import List, Optional, Tuple -import pickle as pkl import torch import uvloop @@ -127,7 +126,7 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 0.0, + temperature=0.0 if use_beam_search else 1.0, top_p=1.0, use_beam_search=use_beam_search, ignore_eos=True, @@ -136,11 +135,8 @@ def run_vllm( if not use_new_beam_search_impl: start = time.perf_counter() - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() - - with open("llm_engine_test.pkl", "wb+") as f: - pkl.dump(outputs, f) else: assert use_beam_search prompts = [prompt for prompt, _, _ in requests] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 66bbb86523fa..5b7587d15084 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1190,8 +1190,8 @@ def _can_append_slots(self, seq_group: SequenceGroup, return False is_prefill = seq_group.is_prefill() - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill, - enable_chunking) + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) if is_prefill and num_lookahead_slots > 0: # Appending prefill slots only happens multi-step and @@ -1199,8 +1199,7 @@ def _can_append_slots(self, seq_group: SequenceGroup, assert self.scheduler_config.is_multi_step and enable_chunking return self.block_manager.can_append_slots( - seq_group=seq_group, - num_lookahead_slots=num_lookahead_slots) + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: no_beam_search = seq_group.sampling_params is None or ( diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 8f8f5f468f8e..4c57a37c8787 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -245,6 +245,8 @@ def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): if not get_pp_group().is_last_rank: return + assert self.frozen_model_input is not None + assert self.frozen_model_input.sampling_metadata is not None self.frozen_model_input.sampling_metadata.selected_token_indices = \ async_tensor_h2d(list(range(self.num_queries)), dtype=torch.long, @@ -293,6 +295,7 @@ def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): self.maybe_advance_sampling_metadata(device, pin_memory) + # MutableModelInputForGPUWithMultiStepMetadata is not subclass of # ModelInputForGPU but it 
wraps the actual input dataclass and adds multi-step # metadata @@ -602,7 +605,8 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: - model_input.maybe_advance_frozen_model_input(self.device, self.pin_memory) + model_input.maybe_advance_frozen_model_input(self.device, + self.pin_memory) frozen_model_input = model_input.frozen_model_input assert frozen_model_input is not None assert frozen_model_input.input_tokens is not None
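
Illustration (not part of the patches above): the selected_token_indices update that
maybe_advance_sampling_metadata performs can be reproduced with a small standalone
sketch. The helper names below (step_one_indices, later_step_indices) are hypothetical
and exist only to make the docstring's worked example concrete; the patch itself builds
the steps-2..N tensor on the worker via async_tensor_h2d(list(range(num_queries)), ...).

def step_one_indices(prompt_chunk_lens, num_decodes):
    # Step 1: each prompt samples only at the last token of its scheduled
    # chunk; the decode tokens follow the prompt tokens in the flattened batch.
    indices = []
    offset = 0
    for chunk_len in prompt_chunk_lens:
        indices.append(offset + chunk_len - 1)
        offset += chunk_len
    indices.extend(range(offset, offset + num_decodes))
    return indices

def later_step_indices(num_queries):
    # Steps 2..N: the former prompts have turned into decodes, so every
    # sequence contributes exactly one token and the indices collapse to
    # 0..num_queries-1, i.e. list(range(num_queries)) as in the patch.
    return list(range(num_queries))

# Worked example from the docstring: 2 prompts with 5 and 8 tokens to
# process, scheduled together with 2 decodes.
assert step_one_indices([5, 8], num_decodes=2) == [4, 12, 13, 14]
assert later_step_indices(num_queries=4) == [0, 1, 2, 3]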