[v1] - Mamba1 Attention Metadata #21249
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Code Review
This pull request introduces v1-style attention metadata support for Mamba-1 models and refactors the Mamba state shape calculation into a centralized `MambaStateShapeCalculator` class. The refactoring improves code organization and maintainability. The v1 support is well integrated, with clear logic separation based on the `VLLM_USE_V1` environment variable.
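As a rough illustration of that separation, the kind of branching described might look like the sketch below; the helper name, shapes, and the direction of the layout change are assumptions for illustration, not the actual vLLM code.

```python
import os

def get_mamba1_cache_shape(intermediate_size: int, state_size: int,
                           conv_kernel: int):
    """Illustrative only: pick the cache layout based on the engine version."""
    use_v1 = os.environ.get("VLLM_USE_V1", "0") == "1"

    conv_state_shape = (conv_kernel - 1, intermediate_size)
    temporal_state_shape = (intermediate_size, state_size)

    if use_v1:
        # Assumption for illustration: the v1 kernels expect the conv state
        # transposed relative to the v0 layout.
        conv_state_shape = conv_state_shape[::-1]
    return conv_state_shape, temporal_state_shape
```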
Thanks for the great work.
Some questions:
- Do we need `enforce-eager` to run mamba1? I'm OK with supporting CUDA graphs in a future PR.
- Can you show some lm-eval results on a mamba1 model?
- Please update `tests/models/language/generation/test_hybrid.py` and `tests/v1/test_oracle.py`.
- Is Jamba supported now?
- Please update the docs, e.g. https://docs.vllm.ai/en/latest/usage/v1_guide.html#mamba-models
This pull request has merge conflicts that must be resolved before it can be merged.
How different is `Mamba1AttentionMetadata` from `Mamba2AttentionMetadata`? Do we really need two separate classes?
Good question. I think we should keep both, since `Mamba2AttentionMetadata` contains fields that aren't relevant for mamba1, like `chunk_indices`, `chunk_offsets`, and other fields related to the Triton kernels that mamba1 doesn't need, which would add overhead to the class.
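To illustrate the difference being described, a rough sketch follows; apart from `chunk_indices` and `chunk_offsets`, the field names are placeholders rather than the actual vLLM fields.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba1AttentionMetadataSketch:
    # The minimal per-batch bookkeeping a mamba1 layer needs (illustrative).
    num_prefills: int
    num_decodes: int
    state_indices_tensor: torch.Tensor        # cache slot used by each request
    has_initial_states: Optional[torch.Tensor]


@dataclass
class Mamba2AttentionMetadataSketch:
    # Same bookkeeping, plus extras that only the chunked mamba2 kernels use.
    num_prefills: int
    num_decodes: int
    state_indices_tensor: torch.Tensor
    has_initial_states: Optional[torch.Tensor]
    chunk_size: int
    chunk_indices: torch.Tensor               # per-chunk mapping for the scan
    chunk_offsets: torch.Tensor
```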
I prefer to keep mamba1 and mamba2 as separate metadata classes. Compared with one metadata class with many optional entries and branches for different types of layers, I prefer this pluggable design. We can extract common logic to a parent class if we find some after more models like minimax are added. @tdoublep WDYT?
@heheda12345 Sorry I missed this question. Yes, I agree: let's keep the metadata classes separate and then factor the common things into a `CommonMambaAttentionMetadata` or something once it is clear what is truly common. This one is ready, and we have the LFM2 and MiniMax-Text ones nearly there too, so we should be able to look at that soon.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: asafg <[email protected]>
I have a few (mostly minor) comments, but this looks nearly ready to go in my view. Great work.
(Review thread on the new `class MambaStateShapeCalculator:` definition, following `return conv_state_shape, temporal_state_shape` in the diff.)
Is there any real reason to introduce the `MambaStateShapeCalculator` class? Couldn't these just be separate utility functions? It creates a lot of diff in the other files for little benefit as far as I can see. Right now it is making the PR look more intrusive than it really is (with 20 files changed).
The reason I added the class is that I needed to add a new function to calculate the mamba1 state shape. The file already had a `get_mamba_state_shape` function, but it was mamba2-only, and I didn't want to introduce branching logic within it to handle both architectures.
I considered loose utility functions like `get_mamba1_state_shape` and `get_mamba2_state_shape`, but the class provides clearer grouping (rough sketch below) since:
- These functions are conceptually related (all calculate Mamba state shapes)
- It makes the API more discoverable - you know all state shape calculations are in one place

What do you think?
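Roughly, the grouping looks like this; the shape formulas are simplified placeholders (no tensor-parallel sharding, head dims, etc.) and not the exact vLLM signatures.

```python
class MambaStateShapeCalculator:
    """Sketch: one place for per-architecture Mamba state-shape helpers."""

    @staticmethod
    def mamba1_state_shape(intermediate_size: int, state_size: int,
                           conv_kernel: int):
        # Conv state holds the last (conv_kernel - 1) inputs per channel;
        # temporal state is one SSM state vector per channel.
        conv_state_shape = (intermediate_size, conv_kernel - 1)
        temporal_state_shape = (intermediate_size, state_size)
        return conv_state_shape, temporal_state_shape

    @staticmethod
    def mamba2_state_shape(intermediate_size: int, state_size: int,
                           conv_kernel: int, num_heads: int, head_dim: int):
        # Mamba2 tracks its temporal state per head rather than per channel.
        conv_state_shape = (intermediate_size, conv_kernel - 1)
        temporal_state_shape = (num_heads, head_dim, state_size)
        return conv_state_shape, temporal_state_shape
```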
I think it's reasonable. My main argument against it was that the change touches a lot of files, but we would have to change the function name to mention mamba2 anyway, which would create a similar level of diff.
Signed-off-by: asafg <[email protected]>
LGTM - thanks for the great work on this PR. Debugging the kernel-level stuff must not have been straightforward. Great that it works now and only needs to slightly extend the abstractions that were put in place for mamba2.
params.ssm_states_batch_stride = ssm_states.stride(0);
params.ssm_states_dim_stride = ssm_states.stride(1);
params.ssm_states_dstate_stride = ssm_states.stride(2);
can this be pulled out of the if/else?
PR looks good. Thanks for adding Jamba and Mamba1 to V1!
The V1 test failure is a known flaky test (see #22385), and the quantization test and Blackwell tests look unrelated. This one is good to merge imo.
Essential Elements of an Effective PR Description Checklist
- `supported_models.md` and `examples` for a new model.

Purpose
Add full v1-style attention-metadata support to Mamba-1 models. The following was added:
- `Mamba1AttentionMetadataBuilder` to create the attention metadata for mamba-1 models
- `MambaStateShapeCalculator` - aggregates the mamba state shape calculation logic for all mamba models under the same static class, for better readability and navigation
- `selective_scan_fwd` updated to support the V1 memory layout (see the sketch below)
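As a rough illustration of the V1 state layout mentioned in the last point: the SSM state cache is indexed by per-request cache slots rather than a dense batch dimension, which is why the kernel takes explicit batch/dim/dstate strides. All names and shapes below are illustrative, not the actual kernel interface.

```python
import torch

# Illustrative cache: one slot per possible request, not a dense batch.
num_cache_slots, dim, dstate = 8, 16, 4
ssm_states = torch.zeros(num_cache_slots, dim, dstate)

# Each running request points at its own slot in the cache.
state_indices = torch.tensor([5, 2, 7])

def reference_state_update(ssm_states: torch.Tensor,
                           state_indices: torch.Tensor,
                           dA: torch.Tensor, dBx: torch.Tensor) -> None:
    # Pure-PyTorch reference of the per-step state update the CUDA kernel
    # performs in place on the selected slots: h <- h * dA + dB * x.
    ssm_states[state_indices] = ssm_states[state_indices] * dA + dBx

batch = state_indices.numel()
reference_state_update(ssm_states, state_indices,
                       dA=torch.rand(batch, dim, dstate),
                       dBx=torch.randn(batch, dim, dstate))
```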
Test Plan
Updated the following tests:
- `test_hybrid.py` - now tests Mamba1 and Jamba in V1
- `test_oracle.py` - removed Mamba1 from the unsupported V1 models

Running all tests in `test_hybrid.py`, `test_oracle.py`, and `test_mamba_ssm.py` (due to the kernel change) passes. Running mamba1 in the main branch would raise an error.
lm_eval results with `state-spaces/mamba-130m-hf`:
- vLLM V0 with Mamba1
- vLLM V1 with Mamba1
This PR now enables vLLM V1 to work with models that use Mamba1, such as pure Mamba1 models and Jamba.
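For reference, a minimal usage sketch; selecting V1 via the environment variable and running in eager mode follow the discussion above, and defaults may differ across vLLM versions.

```python
import os

# Assumption from the discussion above: V1 is opted into via VLLM_USE_V1 and
# CUDA graphs are not yet supported, so eager mode is used.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf", enforce_eager=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```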