Re-enable prefill of max model length #24446
Conversation
Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]>
Code Review
This pull request relaxes an assertion to re-enable prefilling up to the maximum model length and sampling a single token. While the intent is correct, the change as-is will likely cause an IndexError because the underlying buffer for token IDs is not large enough to accommodate the extra token. A fix is required in vllm/v1/worker/gpu_input_batch.py (and likely vllm/v1/worker/tpu_input_batch.py) to increase the buffer size.
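To make the concern concrete, here is a minimal sketch of the failure mode; the buffer shape is an assumption chosen for illustration, not code copied from vllm/v1/worker/gpu_input_batch.py:

import numpy as np

max_model_len = 2048
max_num_reqs = 4
# Assumed layout: one row of token IDs per request, max_model_len slots each.
token_ids_cpu = np.zeros((max_num_reqs, max_model_len), dtype=np.int32)

req_idx = 0
start_idx = max_model_len  # the prefill already occupies every slot
# Storing the single sampled token needs column index max_model_len, which
# does not exist when the buffer has exactly max_model_len columns:
token_ids_cpu[req_idx, start_idx] = 42
# -> IndexError: index 2048 is out of bounds for axis 1 with size 2048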
Signed-off-by: Yannick Schnider <[email protected]>
@WoosukKwon @LucasWilkinson tagging you guys here as author/reviewer of #20291
Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]>
self.input_batch.token_ids_cpu[req_idx,
                               start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
assert end_idx <= self.max_model_len + 1, (
assert end_idx <= self.max_model_len + 1 should fix the immediate issue and probably works with self.max_model_len - 1 - request.num_computed_tokens.
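For context, a small numeric sketch of the boundary case being discussed (the variable names and values are illustrative, not the runner's actual bookkeeping):

max_model_len = 2048      # model context window
num_prompt_tokens = 2048  # the prefill fills the entire context
num_sampled = 1           # a single new token is sampled

start_idx = num_prompt_tokens      # 2048: first free slot after the prompt
end_idx = start_idx + num_sampled  # 2049 == max_model_len + 1

# A bound of end_idx <= max_model_len rejects this request, while the relaxed
# bound accepts exactly one sampled token beyond a full prefill.
assert end_idx <= max_model_len + 1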
I’m a bit stuck because:
- One place adds +1, another -1 — I feel like they cancel out, so maybe this isn’t the real root cause.
- From the vLLM module side, I don't think the runner should care too much about how max_model_len is calculated upstairs. The assert is mostly just a safeguard.
@vadimkantorov brought up a deeper question: why is this assert even triggered? Looking at the call chain, it seems something unexpected happens in the schedule part (my PR isn’t addressing that).
Also, I really like the unit test you added — maybe we can team up and dig into the root cause together. 👍 @yannicks1
That (getting rid of the -1) is exactly what I will address in a follow-up PR (I have the changes working locally already)!
This PR is about prefill of max model length only.
For the decode stopping condition I recently merged a PR in Hugging Face which allows one last decode on the max model length of context before emitting the warning (HF simply truncates the context rather than stopping generation like vLLM does).
I split this into two separate PRs: the 1st (this one) re-enables prefill of max model length, directly addressing the assert failure introduced in #20291; the 2nd (which builds on top of this one) allows one last decode on max model length of context (that's where getting rid of the -1 will happen, along with other minor changes).
The reasons for splitting this are a) making the PRs smaller and easier to review, and b) staying consistent with HF (my HF PR just got merged into main this week, probably not in a release yet).
I’m a bit stuck because:
- One place adds +1, another -1 — I feel like they cancel out, so maybe this isn’t the real root cause.
For prefill this part of the code is untouched; it only applies to running sequences (decodes). So there is no +1 / -1 cancellation happening in my unit test. As I mentioned above, for decodes on the max model length this -1 will be gone (and that's not the only change). I can share the branch later today for clarification...
@vadimkantorov brought up a deeper question: why is this assert even triggered? Looking at the call chain, it seems something unexpected happens in the schedule part (my PR isn’t addressing that).
In my case the assertion is triggered when doing a prefill of max_model_len and requesting 1 output token. I would be surprised if you triggered it another way? @vadimkantorov
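A minimal repro along these lines might look as follows (a hypothetical script, not taken from the PR; the model name and prompt construction are placeholders):

from vllm import LLM, SamplingParams

max_model_len = 2048
llm = LLM(model="JackFram/llama-160m", max_model_len=max_model_len)

# Build a prompt whose tokenized length is exactly max_model_len.
tokenizer = llm.get_tokenizer()
prompt_token_ids = tokenizer.encode("hello " * max_model_len)[:max_model_len]

# A full prefill plus a single requested output token is the case that used
# to trip the assertion.
outputs = llm.generate(
    [{"prompt_token_ids": prompt_token_ids}],
    SamplingParams(max_tokens=1),
)
print(outputs[0].outputs[0].token_ids)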
Probably the same on my side: max_model_len = 1024 and prompt_len happened to be 1023 or something similar. If a fix is out, I can try it.
@vadimkantorov you can use this branch to run your workload. It should fix your issue.
@nicole-lihui here is the branch with the 2nd part (addressing decode): yannicks1#4
I will open the PR to vLLM upstream once this PR is merged (currently it is targeting this branch to highlight the diffs).
Awesome work! Our PRs seem complementary. I'll take inspiration from your test and check whether a concurrency issue might show up.
Signed-off-by: Yannick Schnider <[email protected]>
Hey @tdoublep, I addressed all of your feedback.
@vadimkantorov have you been able to confirm that your workload does not throw the assertion error with this branch?
Signed-off-by: Yannick Schnider <[email protected]>
tests/v1/e2e/test_context_length.py (outdated)
@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [1])
def test_models(
Could we give the test a more descriptive name?
LGTM - thanks for catching this regression and adding the test
Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]> Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]> Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]> Signed-off-by: Karan Goel <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Re-enable prefill at max model length
Purpose
Closes #25120.
Before #20291 it was possible to prefill the model's context to max_model_len and then request a single new token. The change in #20291 added an assertion that prevents this: it now fails when the prompt already consumes the full context and we sample one token. This PR restores the previous behavior (allowing a prefill to max_model_len and then sampling a single token), which matches the behaviour of HuggingFace Transformers.
Proposed change
Relax the assertion/check so that a single sampled token after a prefill that exactly equals max_model_len is allowed. In short: allow the runner to return one new token when the prefill already fills the model's maximum context length. This restores parity with the HuggingFace Transformers behaviour and avoids rejecting otherwise-valid generation requests that only ask for one additional token beyond a full prefill.
Test Plan
Add an end-to-end test that compares vLLM to HuggingFace Transformers: build a prompt that fills the context to exactly max_model_len, then generate with max_tokens=1 and compare the sampled output against HF. The test is parametrized to make it easy to extend to other models / lengths later; the provided version uses JackFram/llama-160m, max_model_len=2048 and max_tokens=1.
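A rough sketch of what such a parametrized test could look like (an illustrative outline, not the test added in tests/v1/e2e/test_context_length.py; the test name, prompt construction, and greedy-decoding comparison are assumptions):

import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams


@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [1])
def test_prefill_to_max_model_len(model, max_model_len, max_tokens):
    tokenizer = AutoTokenizer.from_pretrained(model)
    # A prompt that tokenizes to exactly max_model_len tokens.
    prompt_token_ids = tokenizer.encode("hi " * max_model_len)[:max_model_len]

    # vLLM: prefill the full context, then sample greedily.
    llm = LLM(model=model, max_model_len=max_model_len)
    vllm_out = llm.generate(
        [{"prompt_token_ids": prompt_token_ids}],
        SamplingParams(temperature=0.0, max_tokens=max_tokens),
    )
    vllm_token_ids = list(vllm_out[0].outputs[0].token_ids)

    # HuggingFace reference: greedy-decode the same number of new tokens.
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    input_ids = torch.tensor([prompt_token_ids])
    hf_out = hf_model.generate(input_ids, max_new_tokens=max_tokens,
                               do_sample=False)
    hf_token_ids = hf_out[0, len(prompt_token_ids):].tolist()

    assert vllm_token_ids == hf_token_ids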
Failing behavior (before this PR)
Without the change the unit test fails with the assertion raised by the runner. Example failure seen during testing:
Test result (after this PR)