diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 576d906fa749..c655b0fded6e 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
     cached_req_data = CachedRequestData(
         req_id=req_id,
         resumed_from_preemption=False,
+        new_token_ids=[],
         new_block_ids=[],
         num_computed_tokens=0,
     )
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 82c4b307d48b..e5c60afeb492 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -121,6 +121,8 @@ def schedule(self) -> "SchedulerOutput":
         encoder_budget = self.max_num_encoder_input_tokens
         # Spec decode-related.
         scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
+
+        # For logging.
         scheduled_timestamp = time.monotonic()
 
         # First, schedule the RUNNING requests.
@@ -187,6 +189,15 @@ def schedule(self) -> "SchedulerOutput":
             token_budget -= num_new_tokens
             req_index += 1
 
+            # Speculative decode related.
+            if request.spec_token_ids:
+                num_scheduled_spec_tokens = (num_new_tokens +
+                                             request.num_computed_tokens -
+                                             request.num_tokens)
+                if num_scheduled_spec_tokens > 0:
+                    scheduled_spec_decode_tokens[request.request_id] = (
+                        request.spec_token_ids[:num_scheduled_spec_tokens])
+
             # Encoder-related.
             if encoder_inputs_to_schedule:
                 scheduled_encoder_inputs[request.request_id] = (
@@ -196,11 +207,6 @@ def schedule(self) -> "SchedulerOutput":
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_budget = new_encoder_budget
 
-            # Speculative decode related.
-            if request.spec_token_ids:
-                scheduled_spec_decode_tokens[
-                    request.request_id] = request.spec_token_ids
-
         # Record the LoRAs in scheduled_running_reqs
         requested_loras: Set[int] = set()
         if self.lora_config:
@@ -324,23 +330,24 @@ def schedule(self) -> "SchedulerOutput":
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
-                                        req_to_new_block_ids[req.request_id],
-                                        req.num_computed_tokens)
+                                        req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=True,
             ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=False,
             ) for req in scheduled_running_reqs
         ]
@@ -349,8 +356,8 @@ def schedule(self) -> "SchedulerOutput":
             scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
-            scheduled_encoder_inputs=scheduled_encoder_inputs,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
@@ -366,22 +373,28 @@ def schedule(self) -> "SchedulerOutput":
     def _make_cached_request_data(
         self,
         request: Request,
+        num_scheduled_tokens: int,
+        num_scheduled_spec_tokens: int,
         new_block_ids: List[int],
-        num_computed_tokens: int,
         resumed_from_preemption: bool,
     ) -> "CachedRequestData":
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        if request.request_id in self._cached_reqs_data:
-            req_data = self._cached_reqs_data[request.request_id]
+        num_computed_tokens = request.num_computed_tokens
+        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
+        new_token_ids = request.all_token_ids[
+            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_data = self._cached_reqs_data.get(request.request_id)
+        if req_data is not None:
             req_data.resumed_from_preemption = resumed_from_preemption
+            req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
-                                                      new_block_ids,
-                                                      num_computed_tokens)
+                                                      new_token_ids,
+                                                      new_block_ids)
         self._cached_reqs_data[request.request_id] = req_data
         return req_data
 
diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py
index 2ca8526936e6..47413527c32f 100644
--- a/vllm/v1/core/scheduler_output.py
+++ b/vllm/v1/core/scheduler_output.py
@@ -30,7 +30,6 @@ def from_request(
         cls,
         request: "Request",
         block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "NewRequestData":
         return cls(
             req_id=request.request_id,
@@ -41,7 +40,7 @@ def from_request(
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
             block_ids=block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
         )
 
@@ -54,6 +53,7 @@ class CachedRequestData:
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: bool
+    new_token_ids: List[int]
     new_block_ids: List[int]
     num_computed_tokens: int
 
@@ -62,14 +62,15 @@ def from_request(
         cls,
         request: "Request",
         resumed_from_preemption: bool,
+        new_token_ids: List[int],
         new_block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "CachedRequestData":
         return cls(
             req_id=request.request_id,
             resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
         )
 
 
@@ -91,9 +92,9 @@ class SchedulerOutput:
     # Total number of tokens scheduled for all requests.
     # Equal to sum(num_scheduled_tokens.values())
     total_num_scheduled_tokens: int
-    # req_id -> spec_decode_tokens
-    # If a request does not have any spec decode tokens, it will
-    # not be included in the dictionary.
+    # req_id -> spec_token_ids
+    # If a request does not have any spec decode tokens, it will not be
+    # included in the dictionary.
     scheduled_spec_decode_tokens: Dict[str, List[int]]
     # req_id -> encoder input indices that need processing.
     # E.g., if a request has [0, 1], it could mean the vision encoder needs
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f1212c3554b6..e1d1e43427b8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2,7 +2,7 @@
 
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -184,7 +184,6 @@ def __init__(
                                        self.max_model_len,
                                        self.max_num_tokens),
                                    dtype=np.int32)
-        self.arange_cpu = torch.from_numpy(self.arange_np)
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
         # not make any assumptions about the values in these tensors.
@@ -327,7 +326,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             req_state = self.requests[req_id]
 
             # Update the cached states.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
+            num_computed_tokens = req_data.num_computed_tokens
+            req_state.num_computed_tokens = num_computed_tokens
+            # Add the sampled token(s) from the previous step (if any).
+            # This doesn't include "unverified" tokens like spec decode tokens.
+            num_new_tokens = (num_computed_tokens +
+                              len(req_data.new_token_ids) -
+                              req_state.num_tokens)
+            new_token_ids = (req_data.new_token_ids[-num_new_tokens:]
+                             if num_new_tokens > 0 else [])
+            req_state.output_token_ids.extend(new_token_ids)
+            # Update the block IDs.
             if not req_data.resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
                 req_state.block_ids.extend(req_data.new_block_ids)
@@ -346,12 +355,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
 
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
+                num_computed_tokens)
+            start_index = (len(req_state.block_ids) -
+                           len(req_data.new_block_ids))
             self.input_batch.block_table.append_row(req_index, start_index,
                                                     req_data.new_block_ids)
+            # Add new_token_ids to token_ids_cpu.
+            start_token_index = num_computed_tokens
+            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
+            self.input_batch.token_ids_cpu[
+                req_index,
+                start_token_index:end_token_index] = req_data.new_token_ids
+            # Add spec_token_ids to token_ids_cpu.
+            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                req_id, [])
+            if spec_token_ids:
+                start_index = end_token_index
+                end_token_index += len(spec_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index, start_index:end_token_index] = spec_token_ids
+            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
+            self.input_batch.num_tokens[req_index] = end_token_index
 
+        # Check if the batch has changed. If not, we can skip copying the
+        # sampling metadata from CPU to GPU.
         batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
 
         # Add the new or resumed requests to the persistent batch.
@@ -374,7 +401,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         return batch_changed
 
     def _prepare_inputs(
-        self, scheduler_output: "SchedulerOutput"
+        self,
+        scheduler_output: "SchedulerOutput",
     ) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -387,24 +415,14 @@ def _prepare_inputs(
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens_list: List[int] = []
+        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
-        all_spec_token_ids: List[int] = []
-        num_spec_tokens_list: List[int] = []
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens_list.append(num_tokens)
+            num_scheduled_tokens[i] = num_tokens
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, [])
-            all_spec_token_ids.extend(spec_token_ids)
-            num_spec_tokens_list.append(len(spec_token_ids))
-
-        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
-                                                    dtype=np.int32)
-
         assert max_num_scheduled_tokens > 0
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
@@ -441,78 +459,6 @@ def _prepare_inputs(
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])
 
-        use_spec_decode = len(all_spec_token_ids) > 0
-        if use_spec_decode:
-
-            # 1. Write spec_token_ids to input batch.
-            # Step 1. Get req indices that perform spec decode and repeat
-            #         the req indices by the number of spec tokens. Note
-            #         for requests that don't perform spec decode, the
-            #         number of spec tokens is 0 and the req index is
-            #         repeated 0 times.
-            # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
-            # spec_req_indices: [0, 0, 0, 2, 2, 4]
-            spec_req_indices = np.repeat(self.arange_np[:num_reqs],
-                                         num_spec_tokens_list)
-            # spec_offsets: offsets within each spec token list.
-            # E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
-            spec_offsets = np.concatenate(
-                [self.arange_np[1:val + 1] for val in num_spec_tokens_list])
-            # spec_seq_offsets: offsets within each sequence.
-            # E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
-            # after repeating: [1, 1, 1, 3, 3, 2]
-            # spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
-            #                 = [2, 3, 4, 4, 5, 3]
-            spec_seq_offsets = np.repeat(
-                self.input_batch.num_computed_tokens_cpu[:num_reqs],
-                num_spec_tokens_list) + spec_offsets
-            # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
-            cumsums_spec_offsets = (
-                spec_seq_offsets +
-                spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
-            cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
-                torch.int64)
-            all_spec_token_ids = torch.tensor(all_spec_token_ids,
-                                              device="cpu",
-                                              dtype=self.input_ids_cpu.dtype)
-
-            # Step 2. Write spec token ids to input_ids_cpu.
-            self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
-                0, cumsums_spec_offsets, all_spec_token_ids)
-
-            # 2. Get spec decode logits indices.
-            # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
-            # cu_num_tokens: [4, 104, 107, 207, 209]
-            # num_spec_tokens_list: [3, 0, 2, 0, 1]
-            # num_sampled_tokens: [4, 1, 3, 1, 2]
-            # spec_decode_logits_indices:
-            #   [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-            num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
-            num_sampled_tokens = num_spec_tokens_np + 1
-            # logits_start_loc: [0, 103, 104, 206, 207]
-            logits_start_loc = cu_num_tokens - num_sampled_tokens
-            # [0, 103, 104, 206, 207] ->
-            # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
-            logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
-            # The following three lines:
-            # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-            # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
-            cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
-            # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
-            #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-            cumsums_sampled_offsets = np.repeat(
-                cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
-            # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-            #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
-            #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
-            total_num_sampled_tokens = num_sampled_tokens.sum()
-            sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
-                              cumsums_sampled_offsets)
-
-            # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
-            # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
-            spec_decode_logits_indices = logits_start_loc + sampled_arange
-
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
@@ -606,9 +552,11 @@ def _prepare_inputs(
             suffix_kv_lens=suffix_kv_lens,
         )
 
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
         if use_spec_decode:
-            logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
-                self.device, non_blocking=True)
+            logits_indices = self._calc_spec_decode_metadata(
+                scheduler_output, cu_num_tokens)
         else:
             # NOTE(woosuk): Due to chunked prefills, the batch may contain
             # partial requests. While we should not sample any token
@@ -762,6 +710,53 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
 
                 mrope_pos_ptr += completion_part_len
 
+    def _calc_spec_decode_metadata(
+        self,
+        scheduler_output: "SchedulerOutput",
+        cu_num_tokens: np.ndarray,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Get the number of spec decode tokens for each request.
+        num_reqs = self.input_batch.num_reqs
+        num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
+        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
+            assert req_id is not None
+            num_spec_decode_tokens[i] = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
+
+        # Get spec decode logits indices.
+        # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
+        # cu_num_tokens: [4, 104, 107, 207, 209]
+        # num_spec_tokens_list: [3, 0, 2, 0, 1]
+        # num_sampled_tokens: [4, 1, 3, 1, 2]
+        # spec_decode_logits_indices:
+        #   [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        num_sampled_tokens = num_spec_decode_tokens + 1
+        # logits_start_loc: [0, 103, 104, 206, 207]
+        logits_start_loc = cu_num_tokens - num_sampled_tokens
+        # [0, 103, 104, 206, 207] ->
+        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
+        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
+        # The following three lines:
+        # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
+        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
+        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
+        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
+        #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
+        cumsums_sampled_offsets = np.repeat(
+            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
+        # Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        #         - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
+        #         -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
+        total_num_sampled_tokens = num_sampled_tokens.sum()
+        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
+                          cumsums_sampled_offsets)
+
+        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
+        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
+        spec_decode_logits_indices = logits_start_loc + sampled_arange
+        return torch.from_numpy(spec_decode_logits_indices).to(
+            self.device, non_blocking=True)
+
     def _prepare_sampling(
         self,
         batch_changed: bool,
@@ -773,7 +768,9 @@ def _prepare_sampling(
                 for req_id, req in self.requests.items()}
 
         sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)
+            req_id_output_token_ids,
+            req_to_spec_token_ids,
+            skip_copy=not batch_changed)
         return sampling_metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -960,28 +957,24 @@ def execute_model(
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
-        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
+        req_ids: List[str] = []
+        # Because `input_batch.req_ids` is a list of length `max_num_reqs`,
+        # we need to stop at `num_reqs`.
+        # FIXME(woosuk): This is hacky. Refactor.
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
+            req_ids.append(req_id)
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len >= req_state.num_tokens:
-                request_seq_lens.append((i, req_state, seq_len))
-            else:
-                # Ignore the sampled token from the partial request.
+            if seq_len < req_state.num_tokens:
+                # Ignore the sampled token.
                 # Rewind the generator state as if the token was not sampled.
                 generator = self.input_batch.generators.get(i)
                 if generator is not None:
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -994,29 +987,21 @@ def execute_model(
             scheduler_output,
         )
 
-        # Update batch with the valid generated tokens.
+        # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
+            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
-            for i, req_state, seq_len in request_seq_lens:
-                token_id = valid_sampled_token_ids[i][0]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
-                req_state.output_token_ids.append(token_id)
-                self.input_batch.num_tokens[i] += 1
         else:
+            # Includes spec decode tokens.
             valid_mask = sampled_token_ids != INVALID_TOKEN_ID
             gen_lens = valid_mask.sum(dim=1).tolist()
+            # TODO(woosuk): Optimize this.
             valid_sampled_token_ids = [
                 seq.tolist()
                 for seq in sampled_token_ids[valid_mask].split(gen_lens)
             ]
-            self.input_batch.num_tokens[:num_reqs] += gen_lens
-            for i, req_state, seq_len in request_seq_lens:
-                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
-                self.input_batch.token_ids_cpu[
-                    i, target_slice] = valid_sampled_token_ids[i]
-                req_state.output_token_ids.extend(valid_sampled_token_ids[i])
 
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
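
Note on the index arithmetic (reviewer sketch, not part of the diff): the vectorized steps that _calc_spec_decode_metadata inherits from the removed inline block are easiest to sanity-check in isolation. The snippet below reproduces the worked example from the code comments with plain NumPy; the inputs ([4, 100, 3, 100, 2] scheduled tokens, [3, 0, 2, 0, 1] spec tokens) and the expected indices come straight from those comments, and the variable names merely mirror the diff rather than being runner code.

# Standalone sketch of the spec-decode logits-index computation, using the
# example values documented in the diff's comments.
import numpy as np

num_scheduled_tokens = np.array([4, 100, 3, 100, 2], dtype=np.int32)
num_spec_decode_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

# cu_num_tokens: [4, 104, 107, 207, 209]
cu_num_tokens = np.cumsum(num_scheduled_tokens)
# One "bonus" token is sampled on top of each request's spec tokens.
num_sampled_tokens = num_spec_decode_tokens + 1

# logits_start_loc: [0, 103, 104, 206, 207], then repeated once per
# sampled token: [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)

# Per-request offsets 0..num_sampled_tokens-1, flattened:
# [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
cumsums_sampled_offsets = np.repeat(
    cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = np.arange(total_num_sampled_tokens) - cumsums_sampled_offsets

spec_decode_logits_indices = logits_start_loc + sampled_arange
assert spec_decode_logits_indices.tolist() == [
    0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]

The key trick is the pair of np.repeat calls: one expands the per-request start locations and the other builds the per-request offsets, so the logits indices are produced without any Python-level loop over requests.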