
Conversation

@maxdebayser (Contributor) commented Jul 12, 2024

This is an alternative fix for #6314, worked on today with @tjohnson31415 and @tdoublep, that doesn't disable LoRA for the lm_head in GPTBigCode.

Due to the weight tying, in LoRA mode the lm_head implementation was being replaced with VocabParallelEmbeddingWithLoRA, which is not meant for the lm_head. To fix the issue, this PR initializes the lm_head as an instance of ParallelLMHead and assigns it the weights of the embedding module. Modules of this class are not substituted during LoRA initialization. Because the two layers previously used different padding sizes (64 vs 256), there was a size mismatch, so the same padding is now applied to both. The resulting vocabulary dimension had to be added to bgmv_config.h.

Except for the padding adjustment, this is now basically the same as the Llama code.

cc @robertgshaw2-neuralmagic
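For illustration, a minimal sketch of the pattern described above, assuming a helper function and simplified wiring rather than the exact PR diff (the ParallelLMHead/VocabParallelEmbedding classes and DEFAULT_VOCAB_PADDING_SIZE exist in vLLM, but their signatures may differ between versions, and build_lm_head is a hypothetical name):

```python
# Hedged sketch of the lm_head setup described above; not the exact PR code.
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)


def build_lm_head(config, lora_config, wte: VocabParallelEmbedding):
    """Create an lm_head that is safe to use with LoRA enabled.

    `wte` is the token embedding module; with tied word embeddings its
    weights are reused for the lm_head instead of the freshly created ones.
    """
    unpadded_vocab_size = config.vocab_size
    if lora_config:
        # LoRA adapters may add extra vocabulary entries.
        unpadded_vocab_size += lora_config.lora_extra_vocab_size

    # ParallelLMHead instances are not wrapped by the LoRA machinery, unlike
    # a VocabParallelEmbedding reused directly as the output layer. Using the
    # same padding size as the embedding keeps the tied shapes equal.
    lm_head = ParallelLMHead(
        unpadded_vocab_size,
        config.hidden_size,
        org_num_embeddings=config.vocab_size,
        padding_size=DEFAULT_VOCAB_PADDING_SIZE,
    )
    if getattr(config, "tie_word_embeddings", True):
        # Discard the allocated buffer and share the embedding weights.
        lm_head.weight = wte.weight
    return lm_head
```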

Due to the weight tying, in LoRA mode the lm_head implementation
was being replaced with VocabParallelEmbeddingWithLoRA, which is not
meant for the lm_head. To fix the issue this commit initializes
the lm_head as an instance of ParallelLMHead and assigns it the
weights of the embedding module. Due to previous differences in the
padding of the two layers (64 vs 256), the same padding is now applied
to both. The resulting vocabulary dimension had to be added to
bgmv_config.h.

Signed-off-by: Max de Bayser <[email protected]>
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
Collaborator:

This init function will call create_weights(), which initializes the buffers for the lm_head.

Are we sure these are properly cleaned up when setting self.lm_head.weight?

Additionally, note that self.lm_head.weight = self.transformer.wte.weight will only work for fp16 models, since there is a more complicated state dict that needs overriding if the lm_head is quantized.

I think this is okay (because the lm_head was never allowed to be quantized, since we do not support a quantized lm_head for tied-embedding models), but it's nonetheless something to be aware of that could cause errors in the future. So perhaps we could enforce that the lm_head is not quantized?

@maxdebayser (Contributor, Author):

I've now updated the code to assert that no quantization is used. I've also added a subclass of UnquantizedLinearMethod that doesn't initialize weights, because allocating them is wasteful when they are replaced by the tied embedding weights in the next step.
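A rough sketch of what such a subclass could look like; the class name and the loosened create_weights signature below are assumptions for illustration, not necessarily what the PR ended up with:

```python
# Hedged sketch only: skip weight allocation when the buffer will be
# replaced by the tied embedding weights immediately afterwards.
import torch

from vllm.model_executor.layers.linear import UnquantizedLinearMethod


class NoInitUnquantizedLinearMethod(UnquantizedLinearMethod):
    """UnquantizedLinearMethod variant that does not allocate weights.

    Intended for an lm_head whose weight tensor is overwritten right after
    construction with the word-embedding weights, so creating a throwaway
    buffer in create_weights() would just waste memory.
    """

    def create_weights(self, layer: torch.nn.Module, *args, **kwargs) -> None:
        # Intentionally a no-op; the caller assigns the tied weight later.
        pass
```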

@robertgshaw2-redhat (Collaborator):


My (albeit limited) understanding of LoRA is that it is uncommon to train LoRA adapters on the embeddings / lm_head UNLESS you are trying to add special tokens to the vocabulary (e.g. for ChatML).

Does this implementation enable this? Or is it now the case that we can add LoRA adapters to the vocab, but not to the lm_head?

@njhill (Member) commented Jul 12, 2024

Yes, vLLM's multi-LoRA support in general allows adding tokens/embeddings, including to the lm_head layer (we have used it, e.g., for Mixtral).

Also prevent initialization of lm_head weights prior to
weight tying

Signed-off-by: Max de Bayser <[email protected]>
@njhill (Member) left a comment:

Thanks @maxdebayser!

@njhill (Member) commented Jul 29, 2024

@followumesh any chance you could review this one?

@njhill changed the title from "Fix the lm_head in gptbigcode in lora mode" to "[BigFix] Fix the lm_head in gpt_bigcode in lora mode" on Jul 29, 2024
@followumesh (Contributor) commented Aug 1, 2024

@maxdebayser A sanity-check question: does LoRA correctly replace ParallelLMHead(VocabParallelEmbedding) with the corresponding LoRA layer (I assume VocabParallelEmbeddingWithLoRA)?

@maxdebayser (Contributor, Author):

@followumesh, no, the lm_head should not be replaced with a LoRA class, because the LoRA weights are applied in LogitsProcessorWithLora:

logits = lm_head.linear_method.apply(lm_head, hidden_states)
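Conceptually, the logits processor keeps using the plain lm_head weights and adds the adapter contribution itself. A hedged sketch of that idea follows; only the lm_head.linear_method.apply call mirrors the line above, while the helper name, argument layout, and the explicit low-rank matmul are illustrative rather than vLLM's actual implementation:

```python
# Illustrative only: how LoRA can be applied on top of an unmodified lm_head.
import torch


def compute_logits_with_lora(lm_head, hidden_states: torch.Tensor,
                             lora_a: torch.Tensor, lora_b: torch.Tensor,
                             scaling: float = 1.0) -> torch.Tensor:
    # Base logits from the untouched ParallelLMHead weights, as in the
    # quoted line from LogitsProcessorWithLora.
    logits = lm_head.linear_method.apply(lm_head, hidden_states)
    # Low-rank adapter contribution: (x @ A^T) @ B^T with A: [r, hidden]
    # and B: [vocab, r], added on top of the base logits.
    lora_logits = (hidden_states @ lora_a.t()) @ lora_b.t()
    return logits + scaling * lora_logits
```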

@hmellor (Member) commented Mar 17, 2025

I don't think the failure is related, as the failing test uses Gemma.

@maxdebayser (Contributor, Author):

I've synced with main now to trigger a new build with the latest code.

@jeejeelee (Collaborator):

I'm very sorry for missing this PR. I will look at it ASAP. Thank you.

@jeejeelee (Collaborator):

@maxdebayser Perhaps directly deleting embedding_modules would be more appropriate?

@maxdebayser (Contributor, Author):

@jeejeelee thanks for your suggestion. It works when the embedding and lm_head modules are tied, but not when they aren't.

Let me give a little more context here. We have a GPT-BigCode model with weight tying and some LoRA adapters that unfortunately are not public, but hopefully I can explain what is going on.

As of version 0.7.4, when I try to load the model with --enable-lora, it fails with the following error:

ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1243, in profile_run
ERROR 03-21 19:04:02 [engine.py:411]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1354, in _dummy_run
ERROR 03-21 19:04:02 [engine.py:411]     self.execute_model(model_input, kv_caches, intermediate_tensors)
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 03-21 19:04:02 [engine.py:411]     return func(*args, **kwargs)
ERROR 03-21 19:04:02 [engine.py:411]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1669, in execute_model
ERROR 03-21 19:04:02 [engine.py:411]     self.set_active_loras(model_input.lora_requests,
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1371, in set_active_loras
ERROR 03-21 19:04:02 [engine.py:411]     self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/worker_manager.py", line 167, in set_active_adapters
ERROR 03-21 19:04:02 [engine.py:411]     set_active_adapters_worker(requests, mapping, self._apply_adapters,
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/adapter_commons/utils.py", line 54, in set_active_adapters_worker
ERROR 03-21 19:04:02 [engine.py:411]     apply_adapters_func(requests)
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/worker_manager.py", line 227, in _apply_adapters
ERROR 03-21 19:04:02 [engine.py:411]     self.add_adapter(lora)
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/worker_manager.py", line 250, in add_adapter
ERROR 03-21 19:04:02 [engine.py:411]     self._adapter_manager.activate_adapter(lora_request.lora_int_id)
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/models.py", line 720, in activate_adapter
ERROR 03-21 19:04:02 [engine.py:411]     result = super().activate_adapter(lora_id)
ERROR 03-21 19:04:02 [engine.py:411]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/models.py", line 405, in activate_adapter
ERROR 03-21 19:04:02 [engine.py:411]     module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
ERROR 03-21 19:04:02 [engine.py:411]   File "/home/vllm/my-vllm2/lib64/python3.12/site-packages/vllm/lora/layers.py", line 1070, in set_lora
ERROR 03-21 19:04:02 [engine.py:411]     0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
ERROR 03-21 19:04:02 [engine.py:411]                                            ^^^^^^
ERROR 03-21 19:04:02 [engine.py:411] RuntimeError: The size of tensor a (6144) must match the size of tensor b (49408) at non-singleton dimension 1

With the changes in this PR, and also with your suggestion, the model loads without errors, and both approaches produce the same results with and without the adapter:

curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "my-gpt-bigcode-model-with-weight-tie",
    "prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
    "max_tokens": 10,
    "temperature": 0
  }'| jq
{
  "object": "text_completion",
  "model": "my-gpt-bigcode-model-with-weight-tie",
  "choices": [
    {
      "index": 0,
      "text": " I don't know, I just don't like",
      "logprobs": null,
      "finish_reason": "length",
      "stop_reason": null,
      "prompt_logprobs": null
    }
  ]
}
curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "my-lora",
    "prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
    "max_tokens": 10,
    "temperature": 0
  }'| jq
{
  "object": "text_completion",
  "model": "my-lora",
  "choices": [
    {
      "index": 0,
      "text": " waitress: negative, specials: neutral,",
      "logprobs": null,
      "finish_reason": "length",
      "stop_reason": null,
      "prompt_logprobs": null
    }
  ]
}

But, for testing purposes, I have the same model where I duplicated the weights for the lm_head and set "tie_word_embeddings": false. When I run this model with the changes in this PR, I get the same results as above. Whereas when I just delete embedding_modules, the model loads without crashing but the outputs change:

curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "my-gpt-bigcode-model-without-weight-tie",
    "prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
    "max_tokens": 10,
    "temperature": 0
  }'| jq
{
  "object": "text_completion",
  "model": "my-gpt-bigcode-model-without-weight-tie",
  "choices": [
    {
      "index": 0,
      "text": "",
      "logprobs": null,
      "finish_reason": "stop",
      "stop_reason": null,
      "prompt_logprobs": null
    }
  ]
}

curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "my-lora",                                                     
    "prompt": ["Input: Our waitress seemed less than happy about the prix fixe dinner choices and at one point said, Do you really need to hear the specials? Response:"],
    "max_tokens": 10,
    "temperature": 0
  }'| jq
{
  "object": "text_completion",
  "model": "my-lora",
  "choices": [
    {
      "index": 0,
      "text": "",
      "logprobs": null,
      "finish_reason": "stop",
      "stop_reason": null,
      "prompt_logprobs": null
    }
  ]
}

What's interesting is that the model without weight tying shows this behavior even on version 0.7.4 without --enable-lora.

@jeejeelee (Collaborator):

@maxdebayser Thanks for your explanation.

But, for testing purposes I have the same model where I duplicated the weights for the lm_head and set "tie_word_embeddings": false. When I run this model with the changes in this PR I get the same results as above. Whereas when I just delete embedding_modules, the model loads without crashing but the outputs change:

embedding_modules only works for LoRA, so your second experiment is not related to deleting embedding_modules. Can you upgrade to the latest version of vLLM? I tested setting tie_word_embeddings to False, and it throws an error on the recent main branch.

If your LoRA needs to support embedding_modules, I think we can keep it.

Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
@maxdebayser (Contributor, Author):

@jeejeelee, I finally had time to come back to this. Thanks a lot for your suggestions. The only missing piece to support the model without weight tying was to prevent the loader from skipping the lm_head module in that case.
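For illustration, a hedged sketch of that loader behavior; names and structure are simplified assumptions, not the exact load_weights code in vLLM's GPTBigCode model:

```python
# Hedged sketch: only skip the checkpoint's lm_head entry when the weights
# are actually tied to the embedding; otherwise load it normally.
def load_weights(self, weights):
    params = dict(self.named_parameters(remove_duplicate=False))
    for name, loaded_weight in weights:
        if name == "lm_head.weight" and self.config.tie_word_embeddings:
            # The tied lm_head reuses wte.weight, so this entry is redundant.
            continue
        param = params[name]
        weight_loader = getattr(param, "weight_loader", None)
        if weight_loader is not None:
            weight_loader(param, loaded_weight)
        else:
            param.data.copy_(loaded_weight)
```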

@joerunde added the "ready" label (ONLY add when PR is ready to merge / full CI is needed) on May 7, 2025
mergify bot commented May 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maxdebayser.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on May 15, 2025
mergify bot removed the needs-rebase label on May 15, 2025
@jeejeelee self-assigned this on May 26, 2025
@jeejeelee changed the title from "[BigFix] Fix the lm_head in gpt_bigcode in lora mode" to "[Bugfix] Fix the lm_head in gpt_bigcode in lora mode" on May 26, 2025
@jeejeelee merged commit 561b77a into vllm-project:main on May 26, 2025
66 checks passed
gshtras added a commit to ROCm/vllm that referenced this pull request May 27, 2025
amitm02 pushed a commit to amitm02/vllm that referenced this pull request Jun 1, 2025
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: amit <[email protected]>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: Max de Bayser <[email protected]>
Signed-off-by: minpeter <[email protected]>