
Conversation

tdoublep
Member

@tdoublep tdoublep commented Jun 16, 2025

Purpose

This PR adds support for Activated LoRA (aLoRA): a new family of LoRA adapters that are invoked by including an invocation string in the prompt, with the weights adapted only for the tokens in the sequence that appear after the invocation string. This means one can apply an aLoRA deep in a multi-turn interaction with the model without needing to recompute the entire KV cache: the adapter can reuse the base model's KV cache right up until the point where it is invoked, significantly reducing TTFT.

paper: https://arxiv.org/abs/2504.12397
blog: https://research.ibm.com/blog/inference-friendly-aloras-lora
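
For intuition, here is a minimal sketch of the masked adapter application (hypothetical helper, not the code in this PR); positions before the invocation point receive exactly the base-model output, which is why their KV cache entries can be reused:

import torch

def alora_linear(x: torch.Tensor, W: torch.Tensor,
                 lora_A: torch.Tensor, lora_B: torch.Tensor,
                 invocation_start: int, scaling: float = 1.0) -> torch.Tensor:
    """x: (seq_len, hidden_in); W: (hidden_out, hidden_in) base weight;
    lora_A: (r, hidden_in); lora_B: (hidden_out, r).
    Tokens before `invocation_start` see the pure base-model output, so
    their KV entries are identical to the base model's and can be reused."""
    base_out = x @ W.t()
    lora_out = (x @ lora_A.t()) @ lora_B.t() * scaling
    active = torch.arange(x.shape[0], device=x.device) >= invocation_start
    return torch.where(active.unsqueeze(-1), base_out + lora_out, base_out)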

results from paper:
(figure: results from the paper)

Implementation

We have tried to keep the changes as unobtrusive as possible (but we are happy to hear any suggestions for how the PR can be improved).

If one sets the --enable-activated-lora flag, then the following happens:

  • At model loading time, we replace the QKV projection layers with an equivalent aLoRA implementation (right now, the aLoRA weights modify the QKV projection layers only).
  • Each incoming aLoRA request is scanned for the invocation tokens, and we store the resulting invocation_start in the lora_request object.
  • When computing the block hashes for prefix caching, the invocation_start information is used to determine whether base-model KV cache blocks can be re-used (a rough sketch of this scanning and block-reuse logic follows this list).
  • We introduce a simple ALoRAMetadata class that is needed to pass a single mask tensor down to the LoRA layers.
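
The sketch below illustrates the scanning and prefix-caching points above with hypothetical helpers (names and block-size handling are assumptions, not the PR's actual code):

from typing import Optional

def find_invocation_start(prompt_token_ids: list[int],
                          invocation_tokens: list[int]) -> Optional[int]:
    """Index of the last occurrence of the invocation token sequence in
    the prompt, or None if it does not appear."""
    n, m = len(prompt_token_ids), len(invocation_tokens)
    for start in range(n - m, -1, -1):
        if prompt_token_ids[start:start + m] == invocation_tokens:
            return start
    return None

def num_base_cacheable_tokens(invocation_start: int, block_size: int) -> int:
    """Only full blocks that end before the invocation point are computed
    purely with base-model weights, so only their hashes can be shared
    with (and re-used from) base-model requests."""
    return (invocation_start // block_size) * block_size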

We have tested that the integration works with:

  • Chunked prefill
  • Torch compile
  • Prefix caching
  • Multi-LoRA (can mix standard LoRA and aLoRA together)

Test Plan

We have included an offline example using an uncertainty-detection aLoRA adapter.
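
For reference, here is a rough sketch of what such an offline invocation could look like; the engine-argument name enable_activated_lora (mirroring the CLI flag), the base-model id, and the sampling settings are assumptions rather than the bundled example:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Assumed base model and engine-arg name; the adapter id appears in the
# log output below.
llm = LLM(
    model="ibm-granite/granite-3.2-8b-instruct",
    enable_lora=True,
    enable_activated_lora=True,  # assumed Python-side name of --enable-activated-lora
)

# The invocation string is part of the prompt; the adapter weights are only
# active for tokens after it, so the earlier KV cache is reused as-is.
prompt = (
    "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
    "<|start_of_role|>assistant<|end_of_role|>...base model answer...<|end_of_text|>\n"
    "<|start_of_role|>certainty<|end_of_role|>"
)
outputs = llm.generate(
    prompt,
    SamplingParams(temperature=0.0, max_tokens=8),
    lora_request=LoRARequest("uncertainty", 1,
                             "ibm-granite/granite-3.2-8b-alora-uncertainty"),
)
print(outputs[0].outputs[0].text)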

If the community would like to have this feature in vLLM, we are happy to add more extensive unit and integration tests.

Test Result

I've included some debug print statements in the scheduler to explicitly illustrate the KV cache re-use when applying the aLoRA:

$ python examples/alora/alora_offline_example.py
...
INFO 06-19 14:17:13 [scheduler.py:427] request_id:          0
INFO 06-19 14:17:13 [scheduler.py:428] num_tokens:          12
INFO 06-19 14:17:13 [scheduler.py:429] num_computed_tokens: 0
INFO 06-19 14:17:13 [scheduler.py:430] num_new_tokens:      12
Prompt: '<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>', Generated text: '1. MIT, or Massachusetts Institute of Technology, is a prestigious private research university located in Cambridge, Massachusetts, USA.\n2. It was founded in 1861 and is known for its strong programs in science, technology, engineering, and mathematics (STEM).\n3. MIT is often ranked as one of the top universities globally and is a member of the Ivy League.\n4. It is renowned for its innovative research, influential faculty, and notable alumni.'
WARNING 06-19 14:17:14 [tokenizer.py:295] No tokenizer found in /home/zrltpa/.cache/huggingface/hub/models--ibm-granite--granite-3.2-8b-alora-uncertainty/snapshots/0d8ce48cdd4280a1e8fc37aa1de07537670ecf21, using base model tokenizer instead. (Exception: <class 'transformers.models.granite.configuration_granite.GraniteConfig'>)
INFO 06-19 14:17:14 [scheduler.py:427] request_id:          1
INFO 06-19 14:17:14 [scheduler.py:428] num_tokens:          139
INFO 06-19 14:17:14 [scheduler.py:429] num_computed_tokens: 128
INFO 06-19 14:17:14 [scheduler.py:430] num_new_tokens:      11
Time: 0.5810742378234863
Prompt: '<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>1. MIT, or Massachusetts Institute of Technology, is a prestigious private research university located in Cambridge, Massachusetts, USA.\n2. It was founded in 1861 and is known for its strong programs in science, technology, engineering, and mathematics (STEM).\n3. MIT is often ranked as one of the top universities globally and is a member of the Ivy League.\n4. It is renowned for its innovative research, influential faculty, and notable alumni.<|end_of_text|>\n<|start_of_role|>certainty<|end_of_role|>', Generated text: '85.75%'

(Optional) Documentation Update

tbd


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @tdoublep, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers the foundational components for Activated LoRA (aLoRA) within the vLLM framework. It enables dynamic application of LoRA weights based on prompt content, introduces necessary metadata structures, optimizes KV cache usage for these requests, and provides comprehensive examples for testing and deployment.

Highlights

  • Activated LoRA (aLoRA) Implementation: This PR introduces the initial working implementation of Activated LoRA (aLoRA), a technique that allows for selective application of LoRA weights based on specific 'invocation tokens' within the prompt. This enables dynamic switching between base model and LoRA weights during inference.
  • Core aLoRA Logic: The central mechanism for aLoRA is implemented in vllm/lora/layers.py, where a mask is dynamically generated based on k_offsets and query_start_locs. This mask determines which parts of the output should use the base model's computations and which should incorporate the LoRA adjustments, effectively blending the two outputs (a hypothetical sketch of this mask construction follows this list).
  • KV Cache Optimization for aLoRA: To optimize performance, the KV cache utility (vllm/v1/core/kv_cache_utils.py) has been updated to allow KV cache sharing for aLoRA requests. Specifically, the portion of the prompt before the aLoRA invocation tokens is treated as a base model request for caching purposes, reducing redundant computations.
  • Metadata and Request Handling: New fields (invocation_tokens, k_offset) have been added to LoRARequest to define the aLoRA activation. A new ALoRAMetadata dataclass is introduced in vllm/forward_context.py to pass these activation-specific details through the model's forward pass. The engine processor and GPU model runner are updated to extract and utilize this metadata, including tokenizing invocation strings from adapter configurations.
  • Testing and Examples: New example scripts (examples/alora/alora_server_testing.py, alora_server_testing.sh, new_alora_testing.py) are provided to demonstrate how to set up and interact with a vLLM server running aLoRA, both via the OpenAI-compatible API and directly through the vLLM Python API.
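
As a concrete illustration of the first highlight, a mask for the flattened token batch could be assembled roughly as follows; the exact semantics of k_offsets are an assumption here (taken to be the number of trailing adapter-active tokens per sequence), not something read from the diff:

import torch

def build_alora_mask(query_start_locs: torch.Tensor,
                     k_offsets: torch.Tensor,
                     total_tokens: int) -> torch.Tensor:
    """query_start_locs: (num_seqs + 1,) cumulative start offsets of each
    sequence in the flattened batch; k_offsets: (num_seqs,) assumed number
    of trailing tokens per sequence for which the adapter is active.
    Returns a bool mask of shape (total_tokens, 1)."""
    mask = torch.zeros(total_tokens, dtype=torch.bool)
    for i in range(k_offsets.shape[0]):
        start = int(query_start_locs[i])
        end = int(query_start_locs[i + 1])
        active = min(int(k_offsets[i]), end - start)
        mask[end - active:end] = True
    return mask.unsqueeze(-1)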

@mergify mergify bot added the documentation and v1 labels Jun 16, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an initial implementation of Activated LoRA (aLoRA). The changes include adding new example scripts, modifying core components like the forward context, LoRA request, KV cache utilities, scheduler, processor, and GPU model runner to support aLoRA metadata extraction and application. The core logic for identifying the aLoRA invocation sequence and applying the mask seems correctly implemented. Feedback includes addressing a type mismatch in a metadata class, removing a debug print statement, and clarifying the purpose of layer registration in the compilation config.

@tdoublep tdoublep changed the title from "[Model] Initial working implementation of Activated LoRA" to "[Model] Activated LoRA" on Jun 16, 2025
Co-authored-by: Greenewald <[email protected]>
Co-authored-by: Allison Li <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>

mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tdoublep.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 19, 2025
tdoublep added 3 commits June 19, 2025 14:10
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
@tdoublep tdoublep marked this pull request as ready for review June 19, 2025 14:30
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
@tdoublep
Member Author

tdoublep commented Sep 8, 2025

@hmellor @varun-sundar-rabindranath I've worked through all of the review comments.

I also made a small modification to enable aLoRA for all linear layers (when aLoRA is explicitly enabled by the user) via a mixin class. My initial implementation only enabled it for the QKV layers, but based on feedback from @kgreenewald it is useful to be able to apply aLoRA to other linear layers as well.

Signed-off-by: Thomas Parnell <[email protected]>

mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tdoublep.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
Member

@hmellor hmellor left a comment


Could we use the same name in the config and in engine args?

Comment on lines 1208 to 1213
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
    output = output.flatten(0, 1)
    x = x.flatten(0, 1)
Member


I'm not super familiar with the LoRA codepath, would this flattening need to be reversed or is it fine because flatten doesn't modify the original tensor?

Member Author


Signed-off-by: Thomas Parnell <[email protected]>
tdoublep and others added 2 commits September 11, 2025 06:23
Co-authored-by: Harry Mellor <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
@tdoublep
Member Author

Could we use the same name in the config and in engine args?

@hmellor Done. I also had to refactor a bit due to the recent reorganization of the LoRA layers code.


mergify bot commented Sep 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tdoublep.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 11, 2025
@hmellor
Member

hmellor commented Sep 25, 2025

I've pinged @jeejeelee for a review, if they have time.

@jeejeelee
Collaborator

jeejeelee commented Sep 26, 2025

@tdoublep Thank you for your contribution. Considering that LoRA has many variants, we can probably only maintain and support some commonly used features, so I'm not sure whether we should take on this PR. As you can see, we've also been cleaning up some niche LoRA-related features recently.

csqaiub added a commit to csqaiub/peft that referenced this pull request Sep 28, 2025
This PR migrates Activated LoRA (aLoRA) support from a standalone GitHub repository (see above) to PEFT itself.

Note there is also an active PR for vLLM inference support for Activated LoRA: vllm-project/vllm#19710 . There are also collections of aLoRA models on huggingface (in the ibm-granite org), note that these preexisting models run off of the standalone github repo and will be updated to work with this new PEFT feature if merged.

Description of changes: Activated LoRA is a modification of the LoRA architecture that "activates" the adapter weights only on tokens coming after a specified invocation_string. As a result, the KV values for the text coming before the activation match the KV values of the base model. This allows the KV cache for the input to be shared between the base model and the adapter model, and enables major speedups in inference pipelines (e.g. agentic pipelines) that want to use both base models and adapter models. See the paper for a detailed exploration of use cases and further elaboration.

Other notes:

The crux of the changes is really in layer.py. Everything else is simply managing the alora_offsets quantity, which defines where the weights start to be activated. This is determined by scanning input strings for the invocation_string defined in the aLoraConfig (a rough sketch follows these notes).
    
I believe that aLoRA really only makes sense for CausalLMs, hence I've only implemented this for that model type.

Merging doesn't make sense for aLoRA adapters since the weights are not universally applied to all tokens.
    
I used the LoRA code as a starting point, but did not implement various features of that code that seemed non-essential here.

As of now, invocation_string should probably start and end with special tokens, to avoid tokenizer issues at the boundary. Open to suggestions on how to make this more general if needed.
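
A rough sketch of how alora_offsets could be derived from the invocation string (hypothetical helper; the commit's actual logic may differ):

def alora_offset(prompt: str, invocation_string: str, tokenizer) -> int:
    """Number of trailing tokens of `prompt` for which the adapter weights
    should be active, i.e. the tokens from the last occurrence of
    `invocation_string` onwards. Returns 0 if the string is absent."""
    idx = prompt.rfind(invocation_string)
    if idx == -1:
        return 0
    full = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    prefix = tokenizer(prompt[:idx], add_special_tokens=False)["input_ids"]
    # If invocation_string starts and ends with special tokens (as suggested
    # in the note above), the two tokenizations agree at the boundary.
    return len(full) - len(prefix)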

---------

Co-authored-by: githubnemo <[email protected]>
@lastras

lastras commented Sep 29, 2025

@jeejeelee Thank you for your comment. As stated earlier, this pull request is designed precisely around enhancing the value proposition that vLLM gives to users as an optimized model inference platform. While the contribution we have made here is focused on cache re-use (with the corresponding computation and memory savings), we have seen the same basic idea behind Activated LoRA also emerge in the context of multi-token prediction, suggesting broader applicability for flexibility in how adapters are applied during inference.

It is understandable to focus on a subset of capabilities; perhaps we could have a discussion here on how to adapt this pull request in a way that is architecturally appealing to vLLM for maximum use/reuse/impact.

We would welcome any suggestions you may have.

@lallison2

Adding to the value proposition, our additional testing has verified that Activated LoRA brings significant latency savings in multi-turn conversations, with speedups scaling with model size. For example, using aLoRA instead of regular LoRA in a simple base-call / adapter-call pattern within vLLM can result in up to a 58x end-to-end speedup for the adapter call (using Mistral Large 2 with an initial prompt length of 65k), with savings across all three of queue, prefill, and decode time. The following tests were run with varying initial prompt lengths, a base-model generation length of 256 tokens, and an adapter evaluation length of 16 tokens. Granite-3.2-8B used 1 H100 GPU, Llama-3.3-70B used 4, and Mistral Large 2 used 8.
(figures: end-to-end latency and TTFT speedup factors vs. initial prompt length)

(figures: prefill-time and decode-time speedup factors vs. initial prompt length)
Stable vLLM queue time for longer prompt lengths is a major contributor to the large TTFT speedups (up to 100x for Mistral Large 2 with the longest prompt length).
(figures: queue time and queue-time speedup factor vs. initial prompt length)

@jeejeelee
Collaborator

Sorry for the late feedback caused by the holiday. I'll ask the maintainers for their opinions.

@lastras

lastras commented Oct 8, 2025

Thank you @jeejeelee.

@jeejeelee
Collaborator

After discussing with the maintainers, we're not considering this PR for now. It can remain as a draft PR to help users understand the details of your work. Thank you again for your great work.
Also cc @simon-mo @youkaichao


Labels

documentation, frontend, needs-rebase, tool-calling, v1
