diff --git a/vllm/config.py b/vllm/config.py
index 50adfe8f2d78..5e65db9ef767 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2255,7 +2255,7 @@ def __post_init__(self):
 
 
 SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-                            "draft_model"]
+                            "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
                                       "typical_acceptance_sampler"]
 
@@ -2519,6 +2519,15 @@ def __post_init__(self):
             elif (self.draft_model_config.hf_config.model_type ==
                   "mlp_speculator"):
                 self.method = "mlp_speculator"
+            elif (self.draft_model_config.hf_config.model_type ==
+                  "deepseek_mtp"):
+                self.method = "deepseek_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                        "All Deepseek MTP models only have " \
+                        "one layer. Might need some code changes " \
+                        "to support multiple layers."
+                    )
             else:
                 self.method = "draft_model"
 
@@ -2739,7 +2748,7 @@ def num_lookahead_slots(self) -> int:
         return self.num_speculative_tokens
 
     def use_eagle(self) -> bool:
-        return self.method in ("eagle", "eagle3")
+        return self.method in ("eagle", "eagle3", "deepseek_mtp")
 
     def __repr__(self) -> str:
         method = self.method
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 12c306e98048..b561a1a77487 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1338,7 +1338,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                     is_ngram_enabled = True
                 elif speculative_method == "medusa":
                     is_medusa_enabled = True
-                elif speculative_method in ("eagle", "eagle3"):
+                elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
                     is_eagle_enabled = True
             else:
                 speculative_model = self.speculative_config.get("model")
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 6d7b52aba5f9..03ef7bed0edc 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -19,6 +19,7 @@
 
 from .deepseek_v2 import (DeepseekV2DecoderLayer,
                           get_spec_layer_idx_from_weight_name)
+from .interfaces import SupportsPP
 from .utils import maybe_prefix
 
 
@@ -145,7 +146,7 @@ def compute_logits(
         return logits
 
 
-class DeepSeekMTP(nn.Module):
+class DeepSeekMTP(nn.Module, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 460d645a1a6c..3926a86ee591 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -10,9 +10,10 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
+                                                   FlashAttentionMetadata)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
 
 logger = init_logger(__name__)
 
@@ -25,12 +26,15 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        runner=None,
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
 
+        self.runner = runner
+
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
@@ -106,24 +110,46 @@ def propose(
         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()
 
-        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        max_seq_len = seq_lens.max().item()
-        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
-        attn_metadata = FlashAttentionMetadata(
-            num_actual_tokens=num_tokens,
-            max_query_len=max_num_tokens,
-            query_start_loc=cu_num_tokens,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table,
-            slot_mapping=target_slot_mapping,
-            # TODO(woosuk): Support cascade attention.
-            use_cascade=False,
-            common_prefix_len=0,
-            cu_prefix_query_lens=None,
-            prefix_kv_lens=None,
-            suffix_kv_lens=None,
-        )
+        if self.method in ["eagle", "eagle3"]:
+            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
+            max_seq_len = seq_lens.max().item()
+            max_num_tokens = (cu_num_tokens[1:] -
+                              cu_num_tokens[:-1]).max().item()
+            attn_metadata = FlashAttentionMetadata(
+                num_actual_tokens=num_tokens,
+                max_query_len=max_num_tokens,
+                query_start_loc=cu_num_tokens,
+                max_seq_len=max_seq_len,
+                seq_lens=seq_lens,
+                block_table=block_table,
+                slot_mapping=target_slot_mapping,
+                # TODO(woosuk): Support cascade attention.
+                use_cascade=False,
+                common_prefix_len=0,
+                cu_prefix_query_lens=None,
+                prefix_kv_lens=None,
+                suffix_kv_lens=None,
+            )
+        elif self.method == "deepseek_mtp":
+            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+            max_query_len = query_lens.max().item()
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+
+            assert self.runner is not None
+
+            # FIXME: need to consider multiple kv_cache_groups
+            attn_metadata = self.runner.attn_metadata_builder.build(
+                num_reqs=batch_size,
+                num_actual_tokens=num_tokens,
+                max_query_len=max_query_len,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise ValueError(f"Unsupported method: {self.method}")
+
         if self.use_cuda_graph and \
             num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -136,11 +162,15 @@ def propose(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            last_hidden_states, hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens],
+            ret_hidden_states = self.model(
+                self.input_ids[:num_input_tokens],
+                self.positions[:num_input_tokens],
+                self.hidden_states[:num_input_tokens],
             )
+            if self.method == "deepseek_mtp":
+                last_hidden_states = ret_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
@@ -150,6 +180,10 @@ def propose(
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)
 
+        # TODO: Currently, MTP module released by deepseek only has
+        # one layer. Adapt this code to support multiple layers once
+        # there's a multi-layer MTP module.
+
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
@@ -215,9 +249,9 @@ def propose(
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
-                    input_ids=self.input_ids[:input_batch_size],
-                    positions=self.positions[:input_batch_size],
-                    hidden_states=self.hidden_states[:input_batch_size],
+                    self.input_ids[:input_batch_size],
+                    self.positions[:input_batch_size],
+                    self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -268,7 +302,7 @@ def prepare_inputs(
 
         batch_size = num_rejected_tokens.shape[0]
         BLOCK_SIZE = 1024
-        prepare_input_kernel[(batch_size, )](
+        prepare_eagle_input_kernel[(batch_size, )](
             token_indices,
             cu_target_query_lens,
             cu_num_tokens,
@@ -320,9 +354,9 @@ def dummy_run(
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
-                input_ids=self.input_ids[:num_tokens],
-                positions=self.positions[:num_tokens],
-                hidden_states=self.hidden_states[:num_tokens],
+                self.input_ids[:num_tokens],
+                self.positions[:num_tokens],
+                self.hidden_states[:num_tokens],
             )
 
 
@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
             next_token_ids,
         )
     return next_token_ids, probs
-
-
-@triton.jit
-def prepare_input_kernel(
-    out_ptr,
-    cu_query_lens_ptr,
-    cu_num_tokens_ptr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    # [start_pos, end_pos)
-    start_pos = tl.load(cu_num_tokens_ptr + pid)
-    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
-    num_tokens = end_pos - start_pos
-
-    index_start = tl.load(cu_query_lens_ptr + pid)
-
-    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
-    for i in tl.range(num_blocks):
-        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        tl.store(
-            out_ptr + start_pos + offset,
-            index_start + offset,
-            mask=offset < num_tokens,
-        )
diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py
index ce81a40ee3ae..334258e7f87a 100644
--- a/vllm/v1/spec_decode/utils.py
+++ b/vllm/v1/spec_decode/utils.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+from vllm.triton_utils import tl, triton
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
         return False
 
     return True
+
+
+@triton.jit
+def prepare_eagle_input_kernel(
+    out_ptr,
+    cu_query_lens_ptr,
+    cu_num_tokens_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    # [start_pos, end_pos)
+    start_pos = tl.load(cu_num_tokens_ptr + pid)
+    end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
+    num_tokens = end_pos - start_pos
+
+    index_start = tl.load(cu_query_lens_ptr + pid)
+
+    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
+    for i in tl.range(num_blocks):
+        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        tl.store(
+            out_ptr + start_pos + offset,
+            index_start + offset,
+            mask=offset < num_tokens,
+        )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6d4888363d50..6cdcc3152dc1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -188,12 +188,16 @@ def __init__(
         self.use_aux_hidden_state_outputs = False
         if self.speculative_config:
             self.use_spec_decode = True
+
+            # NOTE(Jiayi): currently we put the entire draft model on
+            # the last PP rank. This is not ideal if there are many
+            # layers in the draft model.
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
                 elif self.speculative_config.use_eagle():
-                    self.drafter = EagleProposer(self.vllm_config,
-                                                 self.device)  # type: ignore
+                    self.drafter = EagleProposer(self.vllm_config, self.device,
+                                                 self)  # type: ignore
                     if self.speculative_config.method == "eagle3":
                         self.use_aux_hidden_state_outputs = True
                 elif self.speculative_config.method == "medusa":
@@ -1362,6 +1366,12 @@ def execute_model(
                                           device=self.device)
             eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
+            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
+            if hasattr(eagle_attn_metadata, "block_table"):
+                block_table = eagle_attn_metadata.block_table
+            else:
+                block_table = None
+
             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
                 target_token_ids = self.input_ids[:num_scheduled_tokens]
@@ -1407,7 +1417,7 @@ def execute_model(
                 target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids,
                 cu_num_tokens=cu_num_tokens,
-                block_table=eagle_attn_metadata.block_table,
+                block_table=block_table,
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
@@ -1718,8 +1728,7 @@ def _dummy_run(
         else:
             hidden_states = outputs
 
-        if self.use_spec_decode and \
-            self.speculative_config.method in ('eagle', 'eagle3'):
+        if self.use_spec_decode and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             self.drafter.dummy_run(num_tokens)
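
To exercise the new path end to end, here is a minimal offline-inference sketch (not part of the diff). It assumes the `vllm.LLM` entry point forwards a `speculative_config` dict to `EngineArgs` (consistent with `self.speculative_config.get("method")` in `arg_utils.py` above) and that the target checkpoint bundles a DeepSeek MTP head whose draft `model_type` is `deepseek_mtp`; the model name, parallelism, and prompt are placeholders, not part of this change.

```python
# Hypothetical usage sketch for the new "deepseek_mtp" speculative method.
# Assumptions: speculative_config is accepted as a dict by LLM/EngineArgs,
# and the checkpoint ships an MTP head; model name and TP size are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # placeholder checkpoint with an MTP head
    tensor_parallel_size=8,           # placeholder; size to your hardware
    speculative_config={
        "method": "deepseek_mtp",
        # The released MTP module has a single layer; SpeculativeConfig
        # warns if this is set above 1.
        "num_speculative_tokens": 1,
    },
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```

Since `SpeculativeConfig.use_eagle()` now returns `True` for `deepseek_mtp`, the proposer reuses the Eagle code path in `eagle.py`, but builds attention metadata through the runner's metadata builder instead of constructing `FlashAttentionMetadata` directly (MLA has no `block_table`).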