[V1] - Split Prefill and Decode for Mamba1 models #22653
Conversation
Code Review
The pull request refactors the Mamba1 mixer to handle prefill and decode operations using dedicated kernels, which also addresses issues with mixed batches. The changes involve splitting the forward pass logic and introducing new metadata to manage prefill and decode contexts. My review has identified a few critical issues related to correctness, particularly an incorrect variable assignment and a logic path that would fail for V0 models. There are also some type hint inconsistencies that should be addressed for better code clarity and maintainability.
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add 🚀.
@amirai21 Can you fix the pre-commit errors and address Gemini's comments? I think there are some useful things in them.
Also CC @tdoublep if you are interested in this PR.
@heheda12345 - thanks for reviewing! I’ve addressed Gemini’s comments and fixed the pre-commit errors. I’ve also added OpenGPT benchmarks for the changes.
ssm_state = self_kv_cache[1]
has_initial_state = mamba1_metadata.has_initial_states
context_lens_tensor = mamba1_metadata.context_lens_tensor
has_initial_states_p = mamba1_metadata.has_initial_states
Just like mamba2, can mamba1_metadata.has_initial_states be a tensor with shape (num_prefills,)? has_initial_states_p is a bit confusing now because it has shape (num_decode + num_prefill,).
context_lens_tensor: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states: torch.Tensor
num_prefills: int
Should has_initial_states be Optional[torch.Tensor]? (I'll fix those in Mamba2AttentionMetadata.) And prefer it to be of shape (num_prefills,), like in mamba2.
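For illustration, a minimal sketch of the metadata shape being suggested here; the class name and the exact field set are assumptions based on the diff above, not necessarily the code that was merged:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba1AttentionMetadata:  # name assumed for illustration
    context_lens_tensor: torch.Tensor
    state_indices_tensor: torch.Tensor
    # Suggested: None when no prefill continues from a cached state,
    # otherwise a bool tensor of shape (num_prefills,), as in mamba2.
    has_initial_states: Optional[torch.Tensor]
    num_prefills: int
```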
fix for mamba2: #22787
- Fixed the has_initial_states type hint to Optional.
- We return has_initial_states here with shape (num_prefills + num_decodes,); in the mixer, split_batch_to_prefill_and_decode takes the num_prefills entries from it (via the V0/V1 ordering logic), and that slice becomes has_initial_states_p, which is passed to the has_prefill flow kernels (see the sketch below).
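A rough sketch of the slicing described above, using a hypothetical helper; the decodes_first flag stands in for the V0/V1 ordering logic mentioned in the comment:

```python
import torch


def take_prefill_initial_states(
    has_initial_states: torch.Tensor,  # shape (num_prefills + num_decodes,)
    num_prefills: int,
    decodes_first: bool,
) -> torch.Tensor:
    """Slice out the per-prefill entries that become has_initial_states_p."""
    if decodes_first:
        # Prefill requests sit at the tail of the batch.
        return has_initial_states[-num_prefills:]
    # Prefill requests sit at the head of the batch.
    return has_initial_states[:num_prefills]
```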
gate,
state_indices_tensor,
query_start_loc,
has_initial_states_p,
Do you need has_initial_states_p=prefill_decode_split.has_initial_states_p?
Yes, we pass has_initial_states_p to causal_conv1d_fn / selective_scan_fn in prefill.
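As a hedged illustration of that flow: both prefill kernels receive the same per-prefill flag tensor. The wrapper below is hypothetical, and the real causal_conv1d_fn / selective_scan_fn signatures take more arguments than shown here:

```python
from typing import Callable, Optional

import torch


def run_prefill_flow(
    hidden_states_p: torch.Tensor,
    has_initial_states_p: Optional[torch.Tensor],
    conv_fn: Callable[..., torch.Tensor],  # stands in for causal_conv1d_fn
    scan_fn: Callable[..., torch.Tensor],  # stands in for selective_scan_fn
) -> torch.Tensor:
    # The flag tensor tells each kernel which prefill sequences should
    # continue from a cached conv / SSM state instead of starting from zeros.
    conv_out = conv_fn(hidden_states_p, has_initial_state=has_initial_states_p)
    return scan_fn(conv_out, has_initial_state=has_initial_states_p)
```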
…tes to has_initial_states_p in attn md, fix typo Signed-off-by: amirk <[email protected]>
LGTM! Can you also run some evals (e.g., by lm_eval) to verify the correctness?
Thanks for reviewing. We ran lm_eval: VLLM_USE_V1=1 HF_ALLOW_CODE_EVAL=1 lm_eval --model vllm --model_args pretrained=state-spaces/mamba-130m-hf,enforce_eager=True,enable_prefix_caching=False,tensor_parallel_size=1 --tasks humaneval --batch_size auto --confirm_run_unsafe_code
|  Tasks  |Version|  Filter   |n-shot|Metric|   |Value |   |Stderr|
|---------|------:|-----------|-----:|------|---|-----:|---|-----:|
|humaneval|      1|create_test|     0|pass@1|   |0.0183|±  |0.0105|

We get the same results with main.
@amirai21 CI failures seem to be related to this PR. Can you fix them?
@heheda12345 I updated the PR with the relevant test fixes. Thanks. I added …
Essential Elements of an Effective PR Description Checklist
- supported_models.md and examples for a new model.

Purpose
There was an issue in the Mamba1 mixer where the decode kernels weren't being called. This PR fixes that, so prefill and decode requests now use the correct kernels, and so do mixed batches.
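Conceptually, the mixer now routes each slice of a mixed batch to the kernel path written for it. A minimal sketch, assuming for simplicity that prefill tokens precede decode tokens in the flattened batch and with the actual kernel calls abstracted behind callables:

```python
from typing import Callable

import torch


def mamba1_mixer_forward(
    hidden_states: torch.Tensor,  # (num_prefill_tokens + num_decode_tokens, hidden)
    num_prefill_tokens: int,
    run_prefill: Callable[[torch.Tensor], torch.Tensor],  # conv1d + selective scan
    run_decode: Callable[[torch.Tensor], torch.Tensor],   # single-step state update
) -> torch.Tensor:
    outputs = []
    if num_prefill_tokens > 0:
        outputs.append(run_prefill(hidden_states[:num_prefill_tokens]))
    if hidden_states.shape[0] > num_prefill_tokens:
        outputs.append(run_decode(hidden_states[num_prefill_tokens:]))
    return torch.cat(outputs, dim=0)
```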
We tested ai21labs/AI21-Jamba-Mini-1.7 on a single H100-80GB GPU and observed improved performance on the ShareGPT dataset.
Running the model:
Running the serving benchmark:
main (50f2aae):
With split prefill decode changes:
Test Plan
Tests for mamba1 and Jamba pass in test_hybrid.py.
Test Result
(Optional) Documentation Update