17 changes: 8 additions & 9 deletions tests/samplers/test_typical_acceptance_sampler.py
@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
     # Next only keep the first 2 draft tokens same as the zero temperature
     # tokens. For the remaining 3 choose some other tokens. In the
     # response we will expect the first 2 tokens to be the same as the
-    # draft tokens and the rest as -1
+    # draft tokens, the third to be the recovered token, and the rest as -1
     draft_token_ids_to_replace = get_draft_token_ids(
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
+    assert torch.all(
+        output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
     assert torch.all(output_token_ids[:, -3:] == -1)


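As context for the new assertion above, here is a minimal sketch (hypothetical shapes, not code from the PR) of the output layout the partial-acceptance test now expects: the two accepted draft tokens, then the recovered token (the target argmax at position 2), then -1 for the remaining slots of the k+1-wide output.

```python
import torch

# Hypothetical shapes mirroring the test: k = 5 draft tokens per sequence.
batch_size, k, vocab_size = 1, 5, 30_000
target_with_bonus_probs = torch.rand(batch_size, k + 1, vocab_size).softmax(dim=-1)
draft_token_ids = torch.randint(vocab_size, (batch_size, k))

# Suppose only the first 2 draft tokens are accepted.
accepted_prefix_len = 2
recovered = target_with_bonus_probs[:, :-1].argmax(dim=-1)  # (batch_size, k)

expected = torch.full((batch_size, k + 1), -1, dtype=torch.long)
expected[:, :accepted_prefix_len] = draft_token_ids[:, :accepted_prefix_len]
# Position 2 holds the recovered token, i.e. the target argmax at that position.
expected[:, accepted_prefix_len] = recovered[:, accepted_prefix_len]
# Positions 3..5 stay -1, matching `output_token_ids[:, -3:] == -1`.
```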
@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
 @pytest.mark.parametrize("seed", list(range(10)))
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_replacement_token_ids(seed: int, device: str):
+def test_get_recovered_token_ids(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler's method for generating
     replacement token IDs.

-    This test verifies that the `_replacement_token_ids` method of the
+    This test verifies that the `_get_recovered_token_ids` method of the
     TypicalAcceptanceSampler correctly identifies the token IDs to be used
-    as replacements based on the target probability distribution.
+    as recovered token IDs based on the target probability distribution.
     Specifically, it ensures that the method correctly identifies the
     tokens with the highest probability for each sequence in the batch.
     """
@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
     typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    expected_replacement_tokens = -torch.ones(
-        (batch_size, k), dtype=torch.long)
-    expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
-                                                     dim=1)
+    expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
     actual_replacement_tokens = (
-        typical_acceptance_sampler._replacement_token_ids(target_probs))
+        typical_acceptance_sampler._get_recovered_token_ids(target_probs))
     assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
28 changes: 9 additions & 19 deletions vllm/model_executor/layers/typical_acceptance_sampler.py
@@ -80,7 +80,7 @@ def forward(
         target_probs = target_with_bonus_probs[:, :-1]
         accepted = self._evaluate_accepted_tokens(target_probs,
                                                   draft_token_ids)
-        recovered_token_ids = self._replacement_token_ids(target_probs)
+        recovered_token_ids = self._get_recovered_token_ids(target_probs)
         output_token_ids = self._create_output(accepted, recovered_token_ids,
                                                draft_token_ids,
                                                bonus_token_ids)
@@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
         accepted_mask = candidates_prob > threshold
         return accepted_mask

-    def _replacement_token_ids(self, target_probs):
+    def _get_recovered_token_ids(self, target_probs):
         """
-        Generate one replacement token ID for each sequence based on target
-        probabilities. The replacement token is used as the fallback option
-        if typical acceptance sampling does not accept any draft tokens for
-        that particular sequence.
-
-        This method computes the token IDs to be replaced by selecting the
-        token with the highest probability for each sequence in the first
-        position. The rest of the output is filled with -1.
+        The recovered token IDs are used to fill the first unmatched
+        (rejected) position with the corresponding target token.

         Parameters
         ----------
@@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs):
         Returns
         -------
         torch.Tensor
-            A tensor of shape (batch_size, k) with the replacement
-            token IDs. Only the first column is set, and the rest of the
-            columns are filled with -1.
+            A tensor of shape (batch_size, k) with the recovered token
+            IDs, which are selected from the target probabilities.
         """
-        max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
-        output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
-                             dtype=self.token_id_dtype,
-                             device=target_probs.device)
-        output[:, 0] = max_indices
-        return output
+        max_indices = torch.argmax(target_probs, dim=-1)
+        return max_indices

Review comments on the updated docstring:

Collaborator: Could you add a comment here saying that we only support k=1 for now?

Contributor (author): The optimization only changes behavior when k > 1; for k = 1 it remains the same as before.

Review comments on the new argmax line:

Collaborator: I was expecting this assertion (https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/tests/samplers/test_typical_acceptance_sampler.py?L407) to fail for test_accept_tokens_partially, since for sequences where some draft tokens are accepted we would now also get a recovered token. As I understand it, the test should see one more token in addition to the two accepted draft tokens. Any idea why the assertion is still passing?

Contributor (author): That's because we didn't check the recovered token. [screenshot of the test attached]

Collaborator: Thanks. Can you please update the test to verify this?
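For readers skimming the diff, here is a minimal standalone sketch (an illustration with assumed shapes and a toy acceptance mask, not vLLM's actual `_create_output` implementation; the bonus-token column is omitted) of how the full-argmax recovered tokens are meant to be consumed: the first rejected position in each sequence is filled with the recovered token for that position, and everything after it is set to -1.

```python
import torch

batch_size, k, vocab_size = 2, 5, 100
target_probs = torch.rand(batch_size, k, vocab_size).softmax(dim=-1)
draft_token_ids = torch.randint(vocab_size, (batch_size, k))
# Toy acceptance mask standing in for _evaluate_accepted_tokens().
accepted = torch.rand(batch_size, k) > 0.5

# After this PR: one recovered token per position, not just position 0.
recovered_token_ids = target_probs.argmax(dim=-1)  # (batch_size, k)

# Length of the accepted prefix = index of the first rejected position.
accepted_prefix = (accepted.long().cumprod(dim=1) == 1).sum(dim=1)

output = torch.full((batch_size, k), -1, dtype=torch.long)
for i in range(batch_size):
    n = int(accepted_prefix[i])
    output[i, :n] = draft_token_ids[i, :n]
    if n < k:
        # The first rejected slot takes the recovered (target-argmax) token.
        output[i, n] = recovered_token_ids[i, n]
print(output)
```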