Fix typical acceptance sampler with correct recovered token ids #8562
Changes from all commits: 29d29a8, 6ba1b4f, 2be977d, 9202ce7, eea43b2, 0b327b7
```diff
@@ -80,7 +80,7 @@ def forward(
         target_probs = target_with_bonus_probs[:, :-1]
         accepted = self._evaluate_accepted_tokens(target_probs,
                                                   draft_token_ids)
-        recovered_token_ids = self._replacement_token_ids(target_probs)
+        recovered_token_ids = self._get_recovered_token_ids(target_probs)
         output_token_ids = self._create_output(accepted, recovered_token_ids,
                                                draft_token_ids,
                                                bonus_token_ids)
@@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
         accepted_mask = candidates_prob > threshold
         return accepted_mask

-    def _replacement_token_ids(self, target_probs):
+    def _get_recovered_token_ids(self, target_probs):
         """
-        Generate one replacement token ID for each sequence based on target
-        probabilities. The replacement token is used as the fallback option
-        if typical acceptance sampling does not accept any draft tokens for
-        that particular sequence.
-
-        This method computes the token IDs to be replaced by selecting the
-        token with the highest probability for each sequence in the first
-        position. The rest of the output is filled with -1.
+        The recovered token ids will fill the first unmatched token
+        by the target token.

         Parameters
         ----------
@@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs):
         Returns
         -------
         torch.Tensor
-            A tensor of shape (batch_size, k) with the replacement
-            token IDs. Only the first column is set, and the rest of the
-            columns are filled with -1.
+            A tensor of shape (batch_size, k) with the recovered token
+            ids which are selected from target probs.
         """
-        max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
-        output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
-                             dtype=self.token_id_dtype,
-                             device=target_probs.device)
-        output[:, 0] = max_indices
-        return output
+        max_indices = torch.argmax(target_probs, dim=-1)
+        return max_indices
```

Review comment on the line `max_indices = torch.argmax(target_probs, dim=-1)`:

I was expecting this assertion (https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/tests/samplers/test_typical_acceptance_sampler.py?L407) to fail for the test test_accept_tokens_partially, since now, for sequences where some draft tokens are accepted, we would get an additional recovered token. As per my understanding, in the test we should get one more token in addition to the 2 accepted draft tokens. Any idea why this assertion is still passing?

Reply: Thanks. Can you please update the test to verify this?
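A minimal sketch of the behavioral change in the diff above, assuming `target_probs` has shape `(batch_size, k, vocab_size)` as the docstring suggests. The function names mirror the diff, but the surrounding sampler class (and its `token_id_dtype` attribute) is omitted, so a dtype argument stands in for it:

```python
import torch

def replacement_token_ids_old(target_probs, token_id_dtype=torch.int64):
    # Old behavior: argmax over the vocab only at the first position;
    # every other column is filled with -1 as an "invalid" placeholder.
    max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
    output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
                         dtype=token_id_dtype,
                         device=target_probs.device)
    output[:, 0] = max_indices
    return output

def get_recovered_token_ids_new(target_probs):
    # New behavior: argmax over the vocab dimension at every position,
    # so a recovered token is available wherever the first unmatched
    # draft token happens to occur.
    return torch.argmax(target_probs, dim=-1)

batch_size, k, vocab_size = 2, 3, 5
probs = torch.rand(batch_size, k, vocab_size)

old = replacement_token_ids_old(probs)
new = get_recovered_token_ids_new(probs)

# Both results have shape (batch_size, k) and agree in column 0,
# but the old version fills columns 1..k-1 with -1.
assert old.shape == new.shape == (batch_size, k)
assert torch.equal(old[:, 0], new[:, 0])
assert (old[:, 1:] == -1).all()
```

This is why the reviewer above expected extra recovered tokens for partially accepted sequences: the new version supplies a valid fallback token at every position, not just the first.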
Review comment: Could you add a comment here saying that we only support k=1 for now?

Reply: The optimization only takes effect when k>1; if k=1, the behavior remains the same as before.
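The reply above can be checked directly: with k=1, taking the argmax over every position collapses to taking it only at the first position, so the old and new computations produce identical results. A small sketch (shapes are assumptions, matching the `(batch_size, k, vocab_size)` convention used earlier):

```python
import torch

batch_size, vocab_size = 4, 10
probs = torch.rand(batch_size, 1, vocab_size)  # k = 1

# Old computation: argmax at the first (and only) position.
old = torch.argmax(probs[:, 0, :], dim=1).unsqueeze(1)

# New computation: argmax over the vocab dimension at all positions.
new = torch.argmax(probs, dim=-1)

# With k=1 the two are identical, shape (batch_size, 1).
assert torch.equal(old, new)
```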