
Conversation

YaoJiayi
Contributor

@YaoJiayi YaoJiayi commented May 20, 2025

Tested on Deepseek R1 with (1) TP=8 and (2) TP=4 * PP=2.

TODOs:

  • Unify MTP (MLA) and EAGLE (normal attention) code paths.
  • Benchmark performance.
  • Optimize the allocation of model layers (target + draft) when there are multiple draft layers (currently the draft model is placed on the last PP rank).
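
For reference, a minimal offline sketch of how this can be exercised, assuming a vLLM build that includes this PR; the model path, parallel size, and prompt are illustrative:

    # Minimal sketch, assuming a build with this PR. The MTP draft layer is read
    # from the main DeepSeek checkpoint, so no separate draft model is passed.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="deepseek-ai/DeepSeek-R1",  # illustrative checkpoint
        tensor_parallel_size=8,           # matches the TP=8 setup tested above
        trust_remote_code=True,
        speculative_config={
            "method": "deepseek_mtp",
            "num_speculative_tokens": 1,  # R1 ships a single MTP layer
        },
    )
    out = llm.generate(["Explain speculative decoding in one sentence."],
                       SamplingParams(temperature=0, max_tokens=64))
    print(out[0].outputs[0].text)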

ruisearch42 and others added 3 commits May 5, 2025 23:44
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

YaoJiayi added 3 commits May 20, 2025 22:05
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
@mergify mergify bot added the ci/build label May 20, 2025
YaoJiayi added 3 commits May 21, 2025 01:47
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
@YaoJiayi YaoJiayi marked this pull request as ready for review May 21, 2025 02:22
@YaoJiayi YaoJiayi changed the title from [WIP][V1] Support Deepseek MTP to [V1] Support Deepseek MTP May 21, 2025
Signed-off-by: YaoJiayi <[email protected]>
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label May 23, 2025
@WoosukKwon
Collaborator

@YaoJiayi LGTM except the minor issue above. Could you please run the DeepSeek model locally and see if it could generate a reasonable output with a reasonable acceptance rate?

Signed-off-by: YaoJiayi <[email protected]>
@YaoJiayi
Contributor Author

@YaoJiayi LGTM except the minor issue above. Could you please run the DeepSeek model locally and see if it could generate a reasonable output with a reasonable acceptance rate?

@WoosukKwon I tested on DeepSeek-R1 with 10 simple prompts. Outputs are reasonable and acceptance rates are 30-70%.
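
For context, "acceptance rate" here is the usual speculative-decoding ratio of accepted draft tokens to proposed draft tokens, e.g.:

    # Acceptance rate as commonly reported for speculative decoding:
    # accepted draft tokens / proposed draft tokens over a run.
    def acceptance_rate(accepted: int, proposed: int) -> float:
        return accepted / proposed if proposed else 0.0

    print(acceptance_rate(35, 100))  # 0.35, i.e. within the 30-70% band above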

Collaborator

@WoosukKwon WoosukKwon left a comment


@YaoJiayi Great! Thanks for the amazing work!

@DarkLight1337
Member

PTAL at the failing V1 test

YaoJiayi added 2 commits May 23, 2025 12:16
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
@WoosukKwon WoosukKwon enabled auto-merge (squash) May 23, 2025 15:10
@vllm-bot vllm-bot merged commit 2628a69 into vllm-project:main May 23, 2025
65 of 67 checks passed
@ruisearch42 ruisearch42 mentioned this pull request May 23, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
gshtras added a commit to ROCm/vllm that referenced this pull request May 27, 2025
* Add files via upload: Add fused MoE kernel tuning configs (fp8_w8a8) for DeepSeek V3/R1 on a single-node 8x NVIDIA H20 96GB setup (vllm-project#18337)

* [Misc] Fix typo (vllm-project#18330)

* Neuron up mistral (vllm-project#18222)

Signed-off-by: Satyajith Chilappagari <[email protected]>

* fix CUDA_check redefinition in vllm-project#17918 (vllm-project#18287)

Signed-off-by: Lucia Fang <[email protected]>
Co-authored-by: Lucia (Lu) Fang <[email protected]>

* [neuron] fix authorization issue (vllm-project#18364)

Signed-off-by: Liangfu Chen <[email protected]>

* [Misc] Allow `AutoWeightsLoader` to skip loading weights with specific substr in name (vllm-project#18358)

Signed-off-by: Isotr0py <[email protected]>

* [Core] [Bugfix]: tensor parallel with prompt embeds (vllm-project#18171)

Signed-off-by: Nan2018 <[email protected]>
Co-authored-by: Andrew Sansom <[email protected]>

* [release] Change dockerhub username for TPU release (vllm-project#18389)

* [Bugfix] fix adding bias twice in ipex GPTQ quantization (vllm-project#18363)

Signed-off-by: rand-fly <[email protected]>

* [doc] update env variable export (vllm-project#18391)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Misc] Add LoRA code owner (vllm-project#18387)

Signed-off-by: Jee Jee Li <[email protected]>

* Update cpu.txt (vllm-project#18398)

Signed-off-by: 汪志鹏 <[email protected]>

* [CI] Add mteb testing to test the accuracy of the embedding model (vllm-project#17175)

* [Bugfix] Fix MRoPE Errors in the Qwen-VL Model When Processing Pure Text (vllm-project#18407)

Co-authored-by: 松灵 <[email protected]>

* [Misc] refactor prompt embedding examples (vllm-project#18405)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Minor] Rename quantization nvfp4 to modelopt_fp4 (vllm-project#18356)

Signed-off-by: mgoin <[email protected]>

* [Model] use AutoWeightsLoader for bloom (vllm-project#18300)

Signed-off-by: calvin chen <[email protected]>

* [Kernel] update comment for KV shape in unified triton attn (vllm-project#18099)

Signed-off-by: haochengxia <[email protected]>

* fix:Build torch wheel inline rather than picking from nightly (vllm-project#18351)

Signed-off-by: Dilip Gowda Bhagavan <[email protected]>

* [TPU] Re-enable the Pallas MoE kernel (vllm-project#18025)

Signed-off-by: Michael Goin <[email protected]>

* [Bugfix] config.head_dim is now explicitly set to None (vllm-project#18432)

Signed-off-by: Gregory Shtrasberg <[email protected]>

* [Bug] Fix moe_sum signature (vllm-project#18440)

Signed-off-by: Bill Nell <[email protected]>

* Revert "[Bugfix] Fix MRoPE Errors in the Qwen-VL Model When Processing Pure Text (vllm-project#18407)" (vllm-project#18456)

Signed-off-by: DarkLight1337 <[email protected]>

* [Bugfix][Failing Test] Fix nixl connector test when prompt size < block size (vllm-project#18429)

Signed-off-by: wwl2755 <[email protected]>

* [Misc] MultiConnector._connectors type (vllm-project#18423)

Signed-off-by: nicklucche <[email protected]>

* [Frontend] deprecate `--device` arg (vllm-project#18399)

Signed-off-by: Kebe <[email protected]>

* [V1] Fix general plugins not loaded in engine for multiproc (vllm-project#18326)

Signed-off-by: Yong Hoon Shin <[email protected]>

* [Misc] refactor disaggregated-prefill-v1 example (vllm-project#18474)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix][Failing Test] Fix test_events.py (vllm-project#18460)

Signed-off-by: rabi <[email protected]>

* [MODEL] FalconH1 (vllm-project#18406)

Signed-off-by: dhia.rhaiem <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Ilyas Chahed <[email protected]>
Co-authored-by: Jingwei Zuo <[email protected]>

* [Doc] fix arg docstring in linear layers (vllm-project#18410)

Signed-off-by: giantcroc <[email protected]>

* [Bugfix] Reduce moe_sum test size to avoid OOM (vllm-project#18484)

Signed-off-by: Bill Nell <[email protected]>

* [Build] fix Dockerfile shell (vllm-project#18402)

* [Misc] Update deprecation message for `--enable-reasoning` (vllm-project#18404)

* [ROCm][Kernel][V1] Enable AMD Radeon GPU Custom Paged Attention on v1 (vllm-project#17004)

Signed-off-by: Hosang Yoon <[email protected]>

* Remove incorrect env value

* Revert "[v1] Support multiple KV cache groups in GPU model runner (vllm-project#17945) (vllm-project#18459)

Signed-off-by: Mark McLoughlin <[email protected]>

* [FEAT][ROCm] Upgrade AITER MLA v1 backend (vllm-project#18338)

Signed-off-by: vllmellm <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>

* [Bugfix] Consistent ascii handling in tool parsers (vllm-project#17704)

Signed-off-by: Sebastian Schönnenbeck <[email protected]>

* [FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (vllm-project#18500)

Signed-off-by: dhia.rhaiem <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Ilyas Chahed <[email protected]>
Co-authored-by: Jingwei Zuo <[email protected]>

* [MISC] update project urls in pyproject.toml (vllm-project#18519)

Signed-off-by: Andy Xie <[email protected]>

* [CI] Fix race condition with StatelessProcessGroup.barrier (vllm-project#18506)

Signed-off-by: Russell Bryant <[email protected]>

* Initialize io_thread_pool attribute in the beginning. (vllm-project#18331)

Signed-off-by: rabi <[email protected]>

* [Bugfix] Inconsistent token calculation compared to HF in llava family (vllm-project#18479)

Signed-off-by: jaycha <[email protected]>

* [BugFix][DP] Send DP wave completion only from `dp_rank==0` (vllm-project#18502)

Signed-off-by: Nick Hill <[email protected]>
Co-authored-by: kourosh hakhamaneshi <[email protected]>

* [Bugfix][Model] Make Olmo2Model weight loading return loaded weights (vllm-project#18504)

Signed-off-by: Shane A <[email protected]>

* [Bugfix] Fix LoRA test (vllm-project#18518)

Signed-off-by: Jee Jee Li <[email protected]>

* [Doc] Fix invalid JSON in example args (vllm-project#18527)

Signed-off-by: DarkLight1337 <[email protected]>

* [Neuron] Update Dockerfile.neuron to use latest neuron release (2.23) (vllm-project#18512)

Signed-off-by: Satyajith Chilappagari <[email protected]>

* Update default neuron config for speculation (vllm-project#18274)

Signed-off-by: Elaine Zhao <[email protected]>
Co-authored-by: Shashwat Srijan <[email protected]>
Co-authored-by: Aakash Shetty <[email protected]>

* Order sequence ids + config update to support specifying custom quantization layers (vllm-project#18279)

Signed-off-by: Elaine Zhao <[email protected]>
Co-authored-by: Tailin Pan <[email protected]>
Co-authored-by: Rishabh Rajesh <[email protected]>
Co-authored-by: Yishan McNabb <[email protected]>
Co-authored-by: Patrick Lange <[email protected]>
Co-authored-by: Maxwell Goldberg <[email protected]>
Co-authored-by: Aakash Shetty <[email protected]>

* [Bugfix] Fix MRoPE Errors in the Qwen-VL Model When Processing Pure Text (vllm-project#18526)

Co-authored-by: 松灵 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>

* [Bugfix] Add kwargs to RequestOutput __init__ to be forward compatible (vllm-project#18513)

Signed-off-by: Linkun <[email protected]>

* [CI/Build] Update bamba test model location (vllm-project#18544)

Signed-off-by: Harry Mellor <[email protected]>

* [Doc] Support --stream arg in openai_completion_client.py script (vllm-project#18388)

Signed-off-by: googs1025 <[email protected]>

* [Bugfix] Use random hidden states in dummy sampler run (vllm-project#18543)

Signed-off-by: Bowen Wang <[email protected]>

* [Doc] Add stream flag for chat completion example (vllm-project#18524)

Signed-off-by: calvin chen <[email protected]>

* [BugFix][CPU] Fix x86 SHM distributed module initialization (vllm-project#18536)

Signed-off-by: jiang.li <[email protected]>

* [Misc] improve Automatic Prefix Caching example (vllm-project#18554)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Misc] Call `ndarray.tobytes()` directly instead of `ndarray.data.tobytes()` (vllm-project#18347)

Signed-off-by: Lukas Geiger <[email protected]>

* [Bugfix] make `test_openai_schema.py` pass (vllm-project#18224)

Signed-off-by: David Xia <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>

* [Platform] Move platform check to right place (vllm-project#18470)

Signed-off-by: wangxiyuan <[email protected]>

* [Compile][Platform] Make PiecewiseBackend pluggable and extendable (vllm-project#18076)

Signed-off-by: Mengqing Cao <[email protected]>
Co-authored-by: youkaichao <[email protected]>

* [Build/CI] Fix CUDA 11.8 build (vllm-project#17679)

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>

* [Tool] Add NIXL installation script (vllm-project#18172)

Signed-off-by: Linkun <[email protected]>

* [V1][Spec Decode][Bugfix] Load quantize weights for EAGLE (vllm-project#18290)

* [Frontend][Bug Fix] Update llama4 pythonic jinja template and llama4_pythonic parser (vllm-project#17917)

Signed-off-by: Kai Wu <[email protected]>

* [Frontend] [Core] Add Tensorizer support for V1, LoRA adapter serialization and deserialization (vllm-project#17926)

Signed-off-by: Sanger Steel <[email protected]>

* [AMD] [P/D] Compute num gpus for ROCm correctly in run_accuracy_test.sh (vllm-project#18568)

Signed-off-by: Randall Smith <[email protected]>

* Re-submit: Fix: Proper RGBA -> RGB conversion for PIL images. (vllm-project#18569)

Signed-off-by: Chenheli Hua <[email protected]>

* [V1][Spec Decoding] Use model_loader.get_model() to load models (vllm-project#18273)

Signed-off-by: Mark McLoughlin <[email protected]>

* Enable hybrid attention models for Transformers backend (vllm-project#18494)

Signed-off-by: Harry Mellor <[email protected]>

* [Misc] refactor: simplify input validation and num_requests handling in _convert_v1_inputs (vllm-project#18482)

Signed-off-by: googs1025 <[email protected]>

* [BugFix] Increase TP execute_model timeout (vllm-project#18558)

Signed-off-by: Nick Hill <[email protected]>

* [Bugfix] Set `KVTransferConfig.engine_id` in post_init (vllm-project#18576)

Signed-off-by: Linkun Chen <[email protected]>

* [Spec Decode] Make EAGLE3 draft token ID mapping optional (vllm-project#18488)

Signed-off-by: Benjamin Chislett <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>

* [Neuron] Remove bypass on EAGLEConfig and add a test (vllm-project#18514)

Signed-off-by: Elaine Zhao <[email protected]>

* [Bugfix][Benchmarks] Fix a benchmark of deepspeed-mii backend to use api_key (vllm-project#17291)

Signed-off-by: Teruaki Ishizaki <[email protected]>

* [Misc] Replace `cuda` hard code with `current_platform` (vllm-project#16983)

Signed-off-by: shen-shanshan <[email protected]>

* [Hardware] correct method signatures for HPU,ROCm,XPU (vllm-project#18551)

Signed-off-by: Andy Xie <[email protected]>

* [V1] [Bugfix] eagle bugfix and enable correct lm_head for multimodal (vllm-project#18034)

Signed-off-by: Ronald Xu <[email protected]>

* [Feature]Add async tensor parallelism using compilation pass (vllm-project#17882)

Signed-off-by: cascade812 <[email protected]>

* [Doc] Update quickstart and install for cu128 using `--torch-backend=auto` (vllm-project#18505)

Signed-off-by: mgoin <[email protected]>

* [Feature][V1]: supports cached_tokens in response usage (vllm-project#18149)

Co-authored-by: simon-mo <[email protected]>

* [Bugfix] Add half type support in reshape_and_cache_cpu_impl on x86 cpu platform (vllm-project#18430)

Signed-off-by: Yuqi Zhang <[email protected]>
Co-authored-by: Yuqi Zhang <[email protected]>

* Migrate docs from Sphinx to MkDocs (vllm-project#18145)

Signed-off-by: Harry Mellor <[email protected]>

* Revert "[V1] [Bugfix] eagle bugfix and enable correct lm_head for multimodal (vllm-project#18034)" (vllm-project#18600)

Signed-off-by: DarkLight1337 <[email protected]>

* [Bugfix][Model] Fix baichuan model loader for tp (vllm-project#18597)

Signed-off-by: Mengqing Cao <[email protected]>

* [V0][Bugfix] Fix parallel sampling performance regression when guided decoding is enabled (vllm-project#17731)

Signed-off-by: Madeesh Kannan <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>

* Add myself as docs code owner (vllm-project#18605)

Signed-off-by: Harry Mellor <[email protected]>

* [Hardware][CPU] Update intel_extension_for_pytorch 2.7.0 and move to `requirements/cpu.txt`  (vllm-project#18542)

Signed-off-by: Kay Yan <[email protected]>

* [CI] fix kv_cache_type argument (vllm-project#18594)

Signed-off-by: Andy Xie <[email protected]>

* [Doc] Fix indent of contributing to vllm (vllm-project#18611)

Signed-off-by: Zerohertz <[email protected]>

* Replace `{func}` with mkdocs style links (vllm-project#18610)

Signed-off-by: Harry Mellor <[email protected]>

* [CI/Build] Fix V1 flag being set in entrypoints tests (vllm-project#18598)

Signed-off-by: DarkLight1337 <[email protected]>

* Fix examples with code blocks in docs (vllm-project#18609)

Signed-off-by: Harry Mellor <[email protected]>

* [Bugfix] Fix transformers model impl ignored for mixtral quant (vllm-project#18602)

Signed-off-by: Tristan Leclercq <[email protected]>

* Include private attributes in API documentation (vllm-project#18614)

Signed-off-by: Harry Mellor <[email protected]>

* [Misc] add Haystack integration (vllm-project#18601)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix][Build/CI] Fixup CUDA compiler version check for CUDA_SUPPORTED_ARCHS (vllm-project#18579)

* [Doc] Fix markdown list indentation for MkDocs rendering (vllm-project#18620)

Signed-off-by: Zerohertz <[email protected]>

* [Doc] Use a different color for the announcement (vllm-project#18616)

Signed-off-by: DarkLight1337 <[email protected]>

* Refactor pplx init logic to make it modular (prepare for deepep) (vllm-project#18200)

Signed-off-by: youkaichao <[email protected]>

* Fix figures in design doc (vllm-project#18612)

Signed-off-by: Harry Mellor <[email protected]>

* [Docs] Change mkdocs to not use directory urls (vllm-project#18622)

Signed-off-by: mgoin <[email protected]>

* [v1] Redo "Support multiple KV cache groups in GPU model runner (vllm-project#17945)" (vllm-project#18593)

Signed-off-by: Chen Zhang <[email protected]>

* [Doc] fix list formatting (vllm-project#18624)

Signed-off-by: David Xia <[email protected]>

* [Doc] Fix top-level API links/docs (vllm-project#18621)

Signed-off-by: DarkLight1337 <[email protected]>

* [Doc] Avoid documenting dynamic / internal modules (vllm-project#18626)

Signed-off-by: DarkLight1337 <[email protected]>

* [Doc] Fix broken links and unlinked docs, add shortcuts to home sidebar (vllm-project#18627)

Signed-off-by: DarkLight1337 <[email protected]>

* [V1] Support Deepseek MTP (vllm-project#18435)

Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>

* Use prebuilt FlashInfer x86_64 PyTorch 2.7 CUDA 12.8 wheel for CI (vllm-project#18537)

Signed-off-by: Huy Do <[email protected]>

* [CI] Enable test_initialization to run on V1 (vllm-project#16736)

Signed-off-by: mgoin <[email protected]>

* [Doc] Update references to doc files (vllm-project#18637)

Signed-off-by: DarkLight1337 <[email protected]>

* [ModelOpt] Introduce VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE env var to control blockscale tensor allocation (vllm-project#18160)

Signed-off-by: Pavani Majety <[email protected]>

* [Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking (vllm-project#18454)

Signed-off-by: Crucifixion-Fxl <[email protected]>
Co-authored-by: Crucifixion-Fxl <[email protected]>

* [Bugfix][Nixl] Fix Preemption Bug (vllm-project#18631)

Signed-off-by: [email protected] <[email protected]>

* config.py: Clarify that only local GGUF checkpoints are supported. (vllm-project#18623)

Signed-off-by: Mathieu Bordere <[email protected]>

* FIX MOE issue in AutoRound format (vllm-project#18586)

Signed-off-by: wenhuach21 <[email protected]>

* [V1][Spec Decode] Small refactors to improve eagle bookkeeping performance (vllm-project#18424)

Signed-off-by: qizixi <[email protected]>

* [Frontend] improve vllm serve --help display (vllm-project#18643)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Model] Add support for Qwen2.5-Omni-7B-AWQ (Qwen2_5OmniForConditionalGeneration) (vllm-project#18647)

* [V1][Spec Decode] Support multi-layer eagle draft model (vllm-project#18030)

Signed-off-by: qizixi <[email protected]>

* [Doc] Update README links, mark external links (vllm-project#18635)

Signed-off-by: DarkLight1337 <[email protected]>

* [MISC][pre-commit] Add pre-commit check for triton import (vllm-project#17716)

Signed-off-by: Mengqing Cao <[email protected]>

* [Doc] Fix indentation problems in V0 Paged Attention docs (vllm-project#18659)

Signed-off-by: DarkLight1337 <[email protected]>

* [Doc] Add community links (vllm-project#18657)

Signed-off-by: DarkLight1337 <[email protected]>

* [Model] use AutoWeightsLoader for gpt2 (vllm-project#18625)

Signed-off-by: zt2370 <[email protected]>

* [Doc] Reorganize user guide (vllm-project#18661)

Signed-off-by: DarkLight1337 <[email protected]>

* [CI/Build] `chmod +x` to `cleanup_pr_body.sh` (vllm-project#18650)

Signed-off-by: DarkLight1337 <[email protected]>

* [MISC] typo fix and clean import (vllm-project#18664)

Signed-off-by: Andy Xie <[email protected]>

* [BugFix] Fix import error for fused_moe (vllm-project#18642)

Signed-off-by: wangxiyuan <[email protected]>

* [CI] enforce import regex instead of re (vllm-project#18665)

Signed-off-by: Aaron Pham <[email protected]>

* fix(regression): clone from reference items (vllm-project#18662)

Signed-off-by: Aaron Pham <[email protected]>

* [CI/Build] fix permission denied issue (vllm-project#18645)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [BugFix][Spec Decode] Improve Prefix Caching Logic in Speculative Decoding (vllm-project#18668)

Signed-off-by: Woosuk Kwon <[email protected]>

* [V1] Fix _pickle.PicklingError: Can't pickle <class 'transformers_modules.deepseek-ai.DeepSeek-V2-Lite... (vllm-project#18640)

Signed-off-by: Seiji Eicher <[email protected]>

* [MISC] correct signature for LoaderFunction (vllm-project#18670)

Signed-off-by: Andy Xie <[email protected]>

* [Misc]Replace `cuda` hard code with `current_platform` in Ray (vllm-project#14668)

Signed-off-by: noemotiovon <[email protected]>

* [Misc][ModelScope] Change to use runtime VLLM_USE_MODELSCOPE (vllm-project#18655)

Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>

* [VLM] Initialize video input support for InternVL models (vllm-project#18499)

Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* Speed up the `kernels/quantization/` tests (vllm-project#18669)

Signed-off-by: mgoin <[email protected]>

* [BUGFIX] catch subclass first for try...except (vllm-project#18672)

Signed-off-by: Andy Xie <[email protected]>

* [Misc] Reduce logs on startup (vllm-project#18649)

Signed-off-by: DarkLight1337 <[email protected]>

* [doc] fix broken links (vllm-project#18671)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [doc] improve readability (vllm-project#18675)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix] Fix cpu usage and cache hit stats reporting on cpu environment (vllm-project#18674)

Signed-off-by: zzzyq <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [CI/build] fix no regex (vllm-project#18676)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Misc] small improve (vllm-project#18680)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [Bugfix] Fix profiling dummy data for Pixtral (vllm-project#18677)

Signed-off-by: DarkLight1337 <[email protected]>

* [Core][Multimodal] Convert PIL Image to array without data copy when hashing (vllm-project#18682)

Signed-off-by: Lukas Geiger <[email protected]>

* [CI/Build][Doc] Update `gte-Qwen2-1.5B-instruct` usage (vllm-project#18683)

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>

* [Misc] Fixed the abnormally high TTFT issue in the PD disaggregation example (vllm-project#18644)

Signed-off-by: zhaohaidao <[email protected]>
Signed-off-by: zhaohaiyuan <[email protected]>
Co-authored-by: zhaohaiyuan <[email protected]>

* refactor: simplify request handler, use positive condition check for handler assignment (vllm-project#18690)

Signed-off-by: googs1025 <[email protected]>

* [Bugfix] Fix the lm_head in gpt_bigcode in lora mode (vllm-project#6357)

Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>

* [CI] add missing argument (vllm-project#18694)

Signed-off-by: Andy Xie <[email protected]>

* [GH] Add issue template for reporting CI failures (vllm-project#18696)

Signed-off-by: DarkLight1337 <[email protected]>

* [Doc] Fix issue template format (vllm-project#18699)

Signed-off-by: DarkLight1337 <[email protected]>

* [Bugfix] Fix Mistral-format models with sliding window (vllm-project#18693)

Signed-off-by: DarkLight1337 <[email protected]>

* [CI/Build] Replace `math.isclose` with `pytest.approx` (vllm-project#18703)

Signed-off-by: DarkLight1337 <[email protected]>

* [CI] fix dump_input for str type (vllm-project#18697)

Signed-off-by: Andy Xie <[email protected]>

* [Model] Add support for YARN in NemotronNAS models (vllm-project#18427)

Signed-off-by: Nave Assaf <[email protected]>

* [CI/Build] Split pooling and generation extended language models tests in CI (vllm-project#18705)

Signed-off-by: Isotr0py <[email protected]>

* [Hardware][Intel-Gaudi] [CI/Build] Add tensor parallel size = 2 test to HPU CI (vllm-project#18709)

Signed-off-by: Lukasz Durejko <[email protected]>

* [Misc] add AutoGen integration (vllm-project#18712)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

* [Bugfix]: handle hf-xet CAS error when loading Qwen3 weights in vLLM (vllm-project#18701)

* [Doc] Improve API docs (vllm-project#18713)

Signed-off-by: DarkLight1337 <[email protected]>

* [Doc] Move examples and further reorganize user guide (vllm-project#18666)

Signed-off-by: DarkLight1337 <[email protected]>

* [Bugfix] Fix Llama GGUF initialization (vllm-project#18717)

Signed-off-by: DarkLight1337 <[email protected]>

* [V1][Sampler] Improve performance of FlashInfer sampling by sampling logits instead of probs (vllm-project#18608)

* Convert `examples` to `ruff-format` (vllm-project#18400)

Signed-off-by: Harry Mellor <[email protected]>

* [Model][Gemma3] Simplify image input validation (vllm-project#18710)

Signed-off-by: Lukas Geiger <[email protected]>

* [Misc] improve web section group title display (vllm-project#18684)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* [V1][Quantization] Add CUDA graph compatible v1 GGUF support (vllm-project#18646)

Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>

* [Model][Gemma3] Cast image pixel values already on CPU (vllm-project#18732)

Signed-off-by: Lukas Geiger <[email protected]>

* [FEAT] [ROCm] Upgrade AITER Fused MoE kernels. (vllm-project#18271)

Signed-off-by: vllmellm <[email protected]>

* [Doc] Update OOT model docs (vllm-project#18742)

Signed-off-by: DarkLight1337 <[email protected]>

* [Doc] Update reproducibility doc and example (vllm-project#18741)

Signed-off-by: DarkLight1337 <[email protected]>

* [Misc] improve docs (vllm-project#18734)

Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>

* feat(rocm-support): support mamba2 on rocm (vllm-project#18565)

Signed-off-by: Islam Almersawi <[email protected]>
Co-authored-by: Islam Almersawi <[email protected]>

* [Hardware][Intel-Gaudi] [CI/Build] Fix multiple containers using the same name in run-hpu-test.sh (vllm-project#18752)

Signed-off-by: Lukasz Durejko <[email protected]>

* [Doc] cleanup deprecated flag for doc (vllm-project#18715)

Signed-off-by: calvin chen <[email protected]>

* Minor fix about MooncakeStoreConnector (vllm-project#18721)

Signed-off-by: baoloongmao <[email protected]>

* [Build] fix cpu build missing libtbbmalloc.so (vllm-project#18744)

Signed-off-by: Kebe <[email protected]>

* [BUG FIX] minicpm (vllm-project#18739)

Signed-off-by: huangyuxiang03 <[email protected]>
Co-authored-by: huangyuxiang03 <[email protected]>

* [Doc]  Convert Sphinx directives ( `{class}`, `{meth}`, `{attr}`, ...) to MkDocs format for better documentation linking (vllm-project#18663)

Signed-off-by: Zerohertz <[email protected]>

* [CI/Build] Remove imports of built-in `re` (vllm-project#18750)

Signed-off-by: DarkLight1337 <[email protected]>

* [V1][Metrics] Add API for accessing in-memory Prometheus metrics (vllm-project#17010)

Signed-off-by: Mark McLoughlin <[email protected]>

* Disable prefix cache by default for benchmark (vllm-project#18639)

Signed-off-by: cascade812 <[email protected]>

* optimize get_kv_cache_torch_dtype (vllm-project#18531)

Signed-off-by: idellzheng <[email protected]>

* [Core] Automatically cast multi-modal input dtype (vllm-project#18756)

Signed-off-by: DarkLight1337 <[email protected]>

* [Bugfix] Mistral tool calling when content is list (vllm-project#18729)

Signed-off-by: mgoin <[email protected]>

---------

Signed-off-by: Satyajith Chilappagari <[email protected]>
Signed-off-by: Lucia Fang <[email protected]>
Signed-off-by: Liangfu Chen <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Nan2018 <[email protected]>
Signed-off-by: rand-fly <[email protected]>
Signed-off-by: reidliu41 <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: 汪志鹏 <[email protected]>
Signed-off-by: mgoin <[email protected]>
Signed-off-by: calvin chen <[email protected]>
Signed-off-by: haochengxia <[email protected]>
Signed-off-by: Dilip Gowda Bhagavan <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: nicklucche <[email protected]>
Signed-off-by: Kebe <[email protected]>
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: rabi <[email protected]>
Signed-off-by: dhia.rhaiem <[email protected]>
Signed-off-by: giantcroc <[email protected]>
Signed-off-by: Hosang Yoon <[email protected]>
Signed-off-by: Mark McLoughlin <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: Sebastian Schönnenbeck <[email protected]>
Signed-off-by: Andy Xie <[email protected]>
Signed-off-by: Russell Bryant <[email protected]>
Signed-off-by: jaycha <[email protected]>
Signed-off-by: Nick Hill <[email protected]>
Signed-off-by: Shane A <[email protected]>
Signed-off-by: Elaine Zhao <[email protected]>
Signed-off-by: Linkun <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: googs1025 <[email protected]>
Signed-off-by: Bowen Wang <[email protected]>
Signed-off-by: jiang.li <[email protected]>
Signed-off-by: Lukas Geiger <[email protected]>
Signed-off-by: David Xia <[email protected]>
Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Kai Wu <[email protected]>
Signed-off-by: Sanger Steel <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Chenheli Hua <[email protected]>
Signed-off-by: Linkun Chen <[email protected]>
Signed-off-by: Benjamin Chislett <[email protected]>
Signed-off-by: Teruaki Ishizaki <[email protected]>
Signed-off-by: shen-shanshan <[email protected]>
Signed-off-by: Ronald Xu <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
Signed-off-by: Madeesh Kannan <[email protected]>
Signed-off-by: Kay Yan <[email protected]>
Signed-off-by: Zerohertz <[email protected]>
Signed-off-by: Tristan Leclercq <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Signed-off-by: Huy Do <[email protected]>
Signed-off-by: Pavani Majety <[email protected]>
Signed-off-by: Crucifixion-Fxl <[email protected]>
Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Mathieu Bordere <[email protected]>
Signed-off-by: wenhuach21 <[email protected]>
Signed-off-by: qizixi <[email protected]>
Signed-off-by: zt2370 <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Seiji Eicher <[email protected]>
Signed-off-by: noemotiovon <[email protected]>
Signed-off-by: zzzyq <[email protected]>
Signed-off-by: zhaohaidao <[email protected]>
Signed-off-by: zhaohaiyuan <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Nave Assaf <[email protected]>
Signed-off-by: Lukasz Durejko <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Islam Almersawi <[email protected]>
Signed-off-by: baoloongmao <[email protected]>
Signed-off-by: huangyuxiang03 <[email protected]>
Signed-off-by: idellzheng <[email protected]>
Co-authored-by: sunyicode0012 <[email protected]>
Co-authored-by: Gong Shufan <[email protected]>
Co-authored-by: Satyajith Chilappagari <[email protected]>
Co-authored-by: Lucia Fang <[email protected]>
Co-authored-by: Lucia (Lu) Fang <[email protected]>
Co-authored-by: Liangfu Chen <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Nan Qin <[email protected]>
Co-authored-by: Andrew Sansom <[email protected]>
Co-authored-by: Kevin H. Luu <[email protected]>
Co-authored-by: Random Fly <[email protected]>
Co-authored-by: Reid <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: 汪志鹏 <[email protected]>
Co-authored-by: wang.yuqi <[email protected]>
Co-authored-by: 燃 <[email protected]>
Co-authored-by: 松灵 <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Calvin Chen <[email protected]>
Co-authored-by: Percy <[email protected]>
Co-authored-by: Dilip Gowda Bhagavan <[email protected]>
Co-authored-by: bnellnm <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: wwl2755 <[email protected]>
Co-authored-by: Nicolò Lucchesi <[email protected]>
Co-authored-by: Kebe <[email protected]>
Co-authored-by: Yong Hoon Shin <[email protected]>
Co-authored-by: Rabi Mishra <[email protected]>
Co-authored-by: Dhia Eddine Rhaiem <[email protected]>
Co-authored-by: younesbelkada <[email protected]>
Co-authored-by: Ilyas Chahed <[email protected]>
Co-authored-by: Jingwei Zuo <[email protected]>
Co-authored-by: GiantCroc <[email protected]>
Co-authored-by: Hyogeun Oh (오효근) <[email protected]>
Co-authored-by: Hosang <[email protected]>
Co-authored-by: Mark McLoughlin <[email protected]>
Co-authored-by: vllmellm <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Sebastian Schoennenbeck <[email protected]>
Co-authored-by: Ning Xie <[email protected]>
Co-authored-by: Russell Bryant <[email protected]>
Co-authored-by: youngrok cha <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: kourosh hakhamaneshi <[email protected]>
Co-authored-by: Shane A <[email protected]>
Co-authored-by: aws-elaineyz <[email protected]>
Co-authored-by: Shashwat Srijan <[email protected]>
Co-authored-by: Aakash Shetty <[email protected]>
Co-authored-by: Tailin Pan <[email protected]>
Co-authored-by: Rishabh Rajesh <[email protected]>
Co-authored-by: Yishan McNabb <[email protected]>
Co-authored-by: Patrick Lange <[email protected]>
Co-authored-by: Maxwell Goldberg <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: lkchen <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: CYJiang <[email protected]>
Co-authored-by: Bowen Wang <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
Co-authored-by: Lukas Geiger <[email protected]>
Co-authored-by: David Xia <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
Co-authored-by: Mengqing Cao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Ekagra Ranjan <[email protected]>
Co-authored-by: Kai Wu <[email protected]>
Co-authored-by: Sanger Steel <[email protected]>
Co-authored-by: rasmith <[email protected]>
Co-authored-by: Chenheli Hua <[email protected]>
Co-authored-by: Benjamin Chislett <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Teruaki Ishizaki <[email protected]>
Co-authored-by: Shanshan Shen <[email protected]>
Co-authored-by: RonaldBXu <[email protected]>
Co-authored-by: cascade <[email protected]>
Co-authored-by: Chauncey <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Yuqi Zhang <[email protected]>
Co-authored-by: Yuqi Zhang <[email protected]>
Co-authored-by: Madeesh Kannan <[email protected]>
Co-authored-by: Kay Yan <[email protected]>
Co-authored-by: Tristan Leclercq <[email protected]>
Co-authored-by: Simon Mo <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Jiayi Yao <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
Co-authored-by: Huy Do <[email protected]>
Co-authored-by: Pavani Majety <[email protected]>
Co-authored-by: Feng XiaoLong <[email protected]>
Co-authored-by: Crucifixion-Fxl <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Mathieu Borderé <[email protected]>
Co-authored-by: Wenhua Cheng <[email protected]>
Co-authored-by: qizixi <[email protected]>
Co-authored-by: Yuanhao WU <[email protected]>
Co-authored-by: ztang2370 <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
Co-authored-by: Seiji Eicher <[email protected]>
Co-authored-by: Chenguang Li <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: AlexZhao <[email protected]>
Co-authored-by: zhaohaiyuan <[email protected]>
Co-authored-by: Maximilien de Bayser <[email protected]>
Co-authored-by: Naveassaf <[email protected]>
Co-authored-by: Łukasz Durejko <[email protected]>
Co-authored-by: dylan <[email protected]>
Co-authored-by: almersawi <[email protected]>
Co-authored-by: Islam Almersawi <[email protected]>
Co-authored-by: Łukasz Durejko <[email protected]>
Co-authored-by: maobaolong <[email protected]>
Co-authored-by: Shawn Huang <[email protected]>
Co-authored-by: huangyuxiang03 <[email protected]>
Co-authored-by: chunxiaozheng <[email protected]>
@handsome-chips

handsome-chips commented May 28, 2025

@YaoJiayi Hi, Could you please tell me how to use this feature? Is this feature compatible with pipeline parallelism?

Can it be used simply by configuring --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}' in the vLLM serve parameters like this?
Is there no need to export the MTP model separately?

    vllm serve deepseek-ai/DeepSeek-R1  \
        --max-num-seqs=80 \
        --max-model-len=8192 \
        --max-num-batched-tokens=16384 \
        --tensor-parallel-size 8 \
        --pipeline-parallel-size 2 \
        --enable-expert-parallel \
        --enable-chunked-prefill \
        --enable-prefix-caching \
        --disable-log-requests \
        --distributed-executor-backend ray \
        --swap-space=64 \
        --enable-reasoning \
        --reasoning-parser deepseek_r1 \
        --trust-remote-code \
        --served-model-name deepseek-r1 \
        --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}'

Thanks~

@YaoJiayi
Contributor Author

@YaoJiayi Hi, Could you please tell me how to use this feature? Is this feature compatible with pipeline parallelism? [...]

PP should be supported. And yes, there's no need to load the MTP module separately; the DeepSeek model weights contain the MTP layer itself.
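
A quick way to check this yourself (a hedged sketch; the local checkpoint path is illustrative, and for DeepSeek-V3/R1 the single MTP module sits one index past the 61 target layers, i.e. under model.layers.61.*):

    # Scan the HF safetensors index of a local DeepSeek-R1 checkpoint for the
    # MTP draft layer. Path is illustrative; adjust to your download location.
    import json

    with open("DeepSeek-R1/model.safetensors.index.json") as f:
        weight_map = json.load(f)["weight_map"]

    mtp_keys = [k for k in weight_map if k.startswith("model.layers.61.")]
    print(f"{len(mtp_keys)} MTP weight tensors found")  # > 0 => no separate export needed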

@handsome-chips

PP should be supported. And yes, there's no need to load the MTP module separately; the DeepSeek model weights contain the MTP layer itself. [...]

But I got the error AttributeError: 'GPUModelRunner' object has no attribute 'drafter' when using the vllm/vllm-openai:v0.9.0 Docker image.

(RayWorkerWrapper pid=1024) ERROR 05-29 12:13:57 [worker_base.py:620] Error executing method 'determine_available_memory'. This might cause deadlock
ERROR 05-29 12:13:57 [core.py:500] EngineCore failed to start.
ERROR 05-29 12:13:57 [core.py:500] Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 491, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 390, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 78, in __init__
    self._initialize_kv_caches(vllm_config)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 137, in _initialize_kv_caches
    available_gpu_memory = self.model_executor.determine_available_memory()
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 75, in determine_available_memory
    output = self.collective_rpc("determine_available_memory")
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 331, in collective_rpc
    return self._run_workers(method, *args, **(kwargs or {}))
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/ray_distributed_executor.py", line 521, in _run_workers
    ray_worker_outputs = ray.get(ray_worker_outputs)
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::RayWorkerWrapper.execute_method() (pid=1024, ip=172.20.91.178, ...)
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 621, in execute_method
    raise e
  File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 612, in execute_method
    return run_method(self, method, args, kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2605, in run_method
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 185, in determine_available_memory
    self.model_runner.profile_run()
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1897, in profile_run
    hidden_states = self._dummy_run(self.max_num_tokens)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1744, in _dummy_run
    assert isinstance(self.drafter, EagleProposer)
AttributeError: 'GPUModelRunner' object has no attribute 'drafter'
2025-05-29T12:13:57.591975948+08:00 Process EngineCore_0:
2025-05-29T12:13:57.593798795+08:00 Traceback (most recent call last):
2025-05-29T12:13:57.594594089+08:00   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
2025-05-29T12:13:57.594601413+08:00     self.run()
2025-05-29T12:13:57.594607309+08:00   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
2025-05-29T12:13:57.594612477+08:00     self._target(*self._args, **self._kwargs)
2025-05-29T12:13:57.594684775+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 504, in run_engine_core
2025-05-29T12:13:57.594715490+08:00     raise e
2025-05-29T12:13:57.595203986+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 491, in run_engine_core
2025-05-29T12:13:57.595215486+08:00     engine_core = EngineCoreProc(*args, **kwargs)
2025-05-29T12:13:57.595220354+08:00                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595226738+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 390, in __init__
2025-05-29T12:13:57.595231680+08:00     super().__init__(vllm_config, executor_class, log_stats,
2025-05-29T12:13:57.595246137+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 78, in __init__
2025-05-29T12:13:57.595251314+08:00     self._initialize_kv_caches(vllm_config)
2025-05-29T12:13:57.595256293+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 137, in _initialize_kv_caches
2025-05-29T12:13:57.595262209+08:00     available_gpu_memory = self.model_executor.determine_available_memory()
2025-05-29T12:13:57.595267479+08:00                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595279188+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 75, in determine_available_memory
2025-05-29T12:13:57.595284617+08:00     output = self.collective_rpc("determine_available_memory")
2025-05-29T12:13:57.595290016+08:00              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595300546+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 331, in collective_rpc
2025-05-29T12:13:57.595305444+08:00     return self._run_workers(method, *args, **(kwargs or {}))
2025-05-29T12:13:57.595311025+08:00            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595315847+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/ray_distributed_executor.py", line 521, in _run_workers
2025-05-29T12:13:57.595320512+08:00     ray_worker_outputs = ray.get(ray_worker_outputs)
2025-05-29T12:13:57.595325247+08:00                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595335067+08:00   File "/usr/local/lib/python3.12/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
2025-05-29T12:13:57.595341027+08:00     return fn(*args, **kwargs)
2025-05-29T12:13:57.595346425+08:00            ^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595351406+08:00   File "/usr/local/lib/python3.12/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
2025-05-29T12:13:57.595357326+08:00     return func(*args, **kwargs)
2025-05-29T12:13:57.595362430+08:00            ^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595374546+08:00   File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 2822, in get
2025-05-29T12:13:57.595379950+08:00     values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
2025-05-29T12:13:57.595385016+08:00                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595398036+08:00   File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 930, in get_objects
2025-05-29T12:13:57.595402849+08:00     raise value.as_instanceof_cause()
2025-05-29T12:13:57.595415493+08:00 ray.exceptions.RayTaskError(AttributeError): ray::RayWorkerWrapper.execute_method() (pid=1024, ip=172.20.91.178, actor_id=527d362287d5fc8d5ed1643a01
2025-05-29T12:13:57.595429653+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 621, in execute_method
2025-05-29T12:13:57.595435028+08:00     raise e
2025-05-29T12:13:57.595439668+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 612, in execute_method
2025-05-29T12:13:57.595444635+08:00     return run_method(self, method, args, kwargs)
2025-05-29T12:13:57.595449329+08:00            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595453954+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2605, in run_method
2025-05-29T12:13:57.595458834+08:00     return func(*args, **kwargs)
2025-05-29T12:13:57.595463488+08:00            ^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595468384+08:00   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
2025-05-29T12:13:57.595473007+08:00     return func(*args, **kwargs)
2025-05-29T12:13:57.595477574+08:00            ^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595482348+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 185, in determine_available_memory
2025-05-29T12:13:57.595487088+08:00     self.model_runner.profile_run()
2025-05-29T12:13:57.595492032+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1897, in profile_run
2025-05-29T12:13:57.595496803+08:00     hidden_states = self._dummy_run(self.max_num_tokens)
2025-05-29T12:13:57.595501407+08:00                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595506781+08:00   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
2025-05-29T12:13:57.595511962+08:00     return func(*args, **kwargs)
2025-05-29T12:13:57.595516572+08:00            ^^^^^^^^^^^^^^^^^^^^^
2025-05-29T12:13:57.595521711+08:00   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1744, in _dummy_run
2025-05-29T12:13:57.595526426+08:00     assert isinstance(self.drafter, EagleProposer)
2025-05-29T12:13:57.595531067+08:00                       ^^^^^^^^^^^^
2025-05-29T12:13:57.595536347+08:00 AttributeError: 'GPUModelRunner' object has no attribute 'drafter'
2025-05-29T12:13:57.596537258+08:00 2025-05-29 12:13:57,596     ERROR worker.py:421 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::RayWorkerWrapper.execute_method()
[... the same "AttributeError: 'GPUModelRunner' object has no attribute 'drafter'" traceback repeats once per remaining Ray worker; duplicates omitted ...]

@handsome-chips
Copy link

Is the configuration for MTP spec speculative_config=SpeculativeConfig(method='deepseek_mtp', model='/deepseek-r1', num_spec_tokens=1)? Is the drafter model /deepseek-r1?

@mahaocong90
Copy link

mahaocong90 commented Jun 4, 2025

@YaoJiayi Hi, I have a problem. My running command is:

    vllm serve /ssd0/model/DeepSeek-R1 \
        --port=8000 --max-model-len=8192 \
        --max-num-batched-tokens=16384 \
        --gpu_memory_utilization 0.9 \
        --tensor-parallel-size 8 \
        --enable-chunked-prefill \
        --enable-prefix-caching \
        --trust-remote-code \
        --served-model-name deepseek-r1 \
        --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}'

I got the error 'RuntimeError: Worker failed with error ''GPUModelRunner' object has no attribute 'attn_metadata_builder'', please check the stack trace above for the root cause' when I send a request to the vLLM server.

I run the server on 8x H20 GPUs with vllm == 0.9.0.1 installed.

My request is like this:

    curl http://127.0.0.1:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "deepseek-r1", "prompt": "China is", "max_tokens": 30, "temperature": 0, "stream": true }'

INFO 06-04 10:14:43 [async_llm.py:261] Added request cmpl-e8f26a6c69014834ab68bc10a77149f1-0.
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] WorkerProc hit an exception.
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] Traceback (most recent call last):
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/executor/multiproc_executor.py", line 517, in worker_busy_loop
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] output = func(*args, **kwargs)
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] return func(*args, **kwargs)
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_worker.py", line 276, in execute_model
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] output = self.model_runner.execute_model(scheduler_output,
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] return func(*args, **kwargs)
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1419, in execute_model
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] draft_token_ids = self.drafter.propose(
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/spec_decode/eagle.py", line 144, in propose
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] attn_metadata = self.runner.attn_metadata_builder.build(
(VllmWorker rank=1 pid=407623) ERROR 06-04 10:14:43 [multiproc_executor.py:522] AttributeError: 'GPUModelRunner' object has no attribute 'attn_metadata_builder'. Did you mean: 'attn_metadata_builders'?
[... the same "AttributeError: 'GPUModelRunner' object has no attribute 'attn_metadata_builder'" traceback is printed, interleaved, by ranks 0 and 2-7; duplicates omitted ...]
ERROR 06-04 10:14:43 [dump_input.py:68] Dumping input data
ERROR 06-04 10:14:43 [dump_input.py:70] V1 LLM engine (v0.9.0.1) with config: model='/ssd0/model/DeepSeek-R1', speculative_config=SpeculativeConfig(method='deepseek_mtp', model='/ssd0/model/DeepSeek-R1', num_spec_tokens=1), tokenizer='/ssd0/model/DeepSeek-R1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=deepseek-r1, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level": 3, "custom_ops": ["none"], "splitting_ops": ["vllm.unified_attention", "vllm.unified_attention_with_output"], "compile_sizes": [], "inductor_compile_config": {"enable_auto_functionalized_v2": false}, "use_cudagraph": true, "cudagraph_num_of_warmups": 1, "cudagraph_capture_sizes": [512, 504, 496, 488, 480, 472, 464, 456, 448, 440, 432, 424, 416, 408, 400, 392, 384, 376, 368, 360, 352, 344, 336, 328, 320, 312, 304, 296, 288, 280, 272, 264, 256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176, 168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 4, 2, 1], "max_capture_size": 512},
ERROR 06-04 10:14:43 [dump_input.py:78] Dumping scheduler output for model execution:
ERROR 06-04 10:14:43 [dump_input.py:79] SchedulerOutput(scheduled_new_reqs=[NewRequestData(req_id=cmpl-e8f26a6c69014834ab68bc10a77149f1-0,prompt_token_ids_len=3,mm_inputs=[],mm_hashes=[],mm_positions=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=30, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None),block_ids=[[1]],num_computed_tokens=0,lora_request=None)], scheduled_cached_reqs=[], num_scheduled_tokens={cmpl-e8f26a6c69014834ab68bc10a77149f1-0: 3}, total_num_scheduled_tokens=3, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=[1], finished_req_ids=[], free_encoder_input_ids=[], structured_output_request_ids={}, grammar_bitmask=null, kv_connector_metadata=null)
ERROR 06-04 10:14:43 [dump_input.py:81] SchedulerStats(num_running_reqs=1, num_waiting_reqs=0, gpu_cache_usage=0.00023164234422057284, prefix_cache_stats=PrefixCacheStats(reset=False, requests=1, queries=3, hits=0), spec_decoding_stats=None)
ERROR 06-04 10:14:43 [core.py:502] EngineCore encountered a fatal error.
ERROR 06-04 10:14:43 [core.py:502] Traceback (most recent call last):
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 493, in run_engine_core
ERROR 06-04 10:14:43 [core.py:502] engine_core.run_busy_loop()
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 520, in run_busy_loop
ERROR 06-04 10:14:43 [core.py:502] self._process_engine_step()
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 545, in _process_engine_step
ERROR 06-04 10:14:43 [core.py:502] outputs = self.step_fn()
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 226, in step
ERROR 06-04 10:14:43 [core.py:502] model_output = self.execute_model(scheduler_output)
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 213, in execute_model
ERROR 06-04 10:14:43 [core.py:502] raise err
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core.py", line 207, in execute_model
ERROR 06-04 10:14:43 [core.py:502] return self.model_executor.execute_model(scheduler_output)
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/executor/multiproc_executor.py", line 158, in execute_model
ERROR 06-04 10:14:43 [core.py:502] (output, ) = self.collective_rpc("execute_model",
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/executor/multiproc_executor.py", line 215, in collective_rpc
ERROR 06-04 10:14:43 [core.py:502] result = get_response(w, dequeue_timeout)
ERROR 06-04 10:14:43 [core.py:502] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/executor/multiproc_executor.py", line 202, in get_response
ERROR 06-04 10:14:43 [core.py:502] raise RuntimeError(
ERROR 06-04 10:14:43 [core.py:502] RuntimeError: Worker failed with error ''GPUModelRunner' object has no attribute 'attn_metadata_builder'', please check the stack trace above for the root cause
ERROR 06-04 10:14:43 [async_llm.py:408] AsyncLLM output_handler failed.
ERROR 06-04 10:14:43 [async_llm.py:408] Traceback (most recent call last):
ERROR 06-04 10:14:43 [async_llm.py:408] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/async_llm.py", line 366, in output_handler
ERROR 06-04 10:14:43 [async_llm.py:408] outputs = await engine_core.get_output_async()
ERROR 06-04 10:14:43 [async_llm.py:408] File "/usr/local/lib/python3.10/dist-packages/vllm/v1/engine/core_client.py", line 806, in get_output_async
ERROR 06-04 10:14:43 [async_llm.py:408] raise self._format_exception(outputs) from None
ERROR 06-04 10:14:43 [async_llm.py:408] vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
INFO 06-04 10:14:43 [async_llm.py:333] Request cmpl-e8f26a6c69014834ab68bc10a77149f1-0 failed (engine dead).
INFO: Shutting down
INFO: Waiting for application shutdown.
INFO: Application shutdown complete.
INFO: Finished server process [407269]

Thanks~

@rain7996
Copy link

rain7996 commented Jun 4, 2025

I ran into the same problem as @mahaocong90. Does anybody know why?

@DiegoD94
Copy link

DiegoD94 commented Jun 4, 2025

Got the same error here with DeepSeek MTP + the V1 engine and the MLA attention backend.

@YaoJiayi
Copy link
Contributor Author

YaoJiayi commented Jun 4, 2025

(VllmWorker rank=5 pid=407627) ERROR 06-04 10:14:43 [multiproc_executor.py:522] AttributeError: 'GPUModelRunner' object has no attribute 'attn_metadata_builder'. Did you mean: 'attn_metadata_builders'?
There might be another PR that renamed attn_metadata_builder to attn_metadata_builders after this one was merged.

@DiegoD94
Copy link

DiegoD94 commented Jun 4, 2025

I tried changing line 144 in eagle.py to self.runner.attn_metadata_builders[0].build, but got another error. Does it support MTP size > 1 for now, or only MTP size = 1?

@YaoJiayi
Copy link
Contributor Author

YaoJiayi commented Jun 4, 2025

For now, it should be 1 because the number of MTP layers is 1. I haven't tested MTP size > 1, but it should be easy to support.
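
To make the working setup concrete, here is a minimal offline sketch (assuming vLLM's Python LLM entry point on a recent V1 build; the checkpoint path is a placeholder, and the MTP drafter weights come from the DeepSeek checkpoint itself, so no separate draft model is passed):

    # Minimal sketch of DeepSeek MTP speculative decoding via the offline API.
    # Assumptions: vLLM V1 engine, 8 GPUs, and a placeholder checkpoint path.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="/models/DeepSeek-R1",        # placeholder local path
        tensor_parallel_size=8,
        trust_remote_code=True,
        speculative_config={
            "method": "deepseek_mtp",
            "num_speculative_tokens": 1,    # one MTP layer -> one draft token
        },
    )
    outputs = llm.generate(["China is"],
                           SamplingParams(temperature=0, max_tokens=30))
    print(outputs[0].outputs[0].text)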

@DiegoD94
Copy link

DiegoD94 commented Jun 4, 2025

Thanks! I tried with MTP size = 1 and the above quick fix; it works fine for now, but I guess a formal patch to align the syntax is needed. @rain7996 @mahaocong90
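
For anyone else hitting this before a formal fix lands, the interim change is just the one-liner discussed above, sketched here as a patch (the [0] index assumes a single KV-cache group, matching the single-MTP-layer setup):

    --- a/vllm/v1/spec_decode/eagle.py
    +++ b/vllm/v1/spec_decode/eagle.py
    @@ line 144, in propose @@
    -        attn_metadata = self.runner.attn_metadata_builder.build(
    +        attn_metadata = self.runner.attn_metadata_builders[0].build(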

@mahaocong90
Copy link

I tried changing line 144 in eagle.py to self.runner.attn_metadata_builders[0].build, but got another error. Does it support MTP size > 1 for now, or only MTP size = 1?

Yes, it works. Thanks~

@kongweiming
Copy link

Tested on Deepseek R1 with (1) TP=8 and (2) TP=4 * PP=2.

TODOs:

  • Unify MTP (MLA) and EAGLE (normal attention) code paths.
  • Benchmark performance.
  • Optimize model layers (target+draft) allocation when there are more draft layers (i.e., currently the draft model is placed on the last pp rank).

I used TP=8, PP=2 across two machines to start deepseek-r1 MTP, and it did not work. I used TP=16, which works, but the performance is too poor. May I ask: does the TP=8, PP=2 config work? My inference service start command is:
    VLLM_ATTENTION_BACKEND=FLASHINFER python3 -m vllm.entrypoints.api_server --model=/models/DeepSeek-R1-0528/ \
        --host=127.0.0.1 --port=11234 --trust-remote-code \
        --tensor-parallel-size=8 \
        --pipeline-parallel-size=2 \
        --gpu-memory-utilization=0.95 \
        --max-logprobs=20 --swap-space=0 --compilation-config=3 --max-num-seqs=512 \
        --distributed-executor-backend=ray \
        --no-enable-prefix-caching \
        --enable-chunked-prefill \
        --max-num-batched-tokens=16384 \
        --max-model-len=65536 \
        --disable-log-requests \
        --served-model-name=deepseek-r1 \
        --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}'

minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: YaoJiayi <[email protected]>
Co-authored-by: Rui Qiao <[email protected]>
Signed-off-by: minpeter <[email protected]>
Labels
ci/build frontend ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1