9 changes: 2 additions & 7 deletions tests/v1/sample/test_rejection_sampler.py
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
batch_size = len(spec_tokens)
return SamplingMetadata(
temperature=0.0,
temperature=torch.tensor([]),
all_greedy=True,
all_random=False,
rejection_sampling=True,
spec_token_ids=spec_tokens,
top_p=None,
top_k=None,
no_top_p=False,
no_top_k=False,
min_p=torch.empty(batch_size, ),
no_min_p=True,
generators={},
max_num_logprobs=0,
no_penalties=False,
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
min_tokens=[],
stop_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
)
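
For context, a minimal sketch (not part of the diff) of the reshaped field this test helper now fills in: the former parallel lists min_tokens and stop_token_ids are folded into a single dict keyed by batch index, and the no_top_p / no_top_k / no_min_p flags disappear because passing None now means "not in use". The concrete values below are illustrative only.

from typing import Dict, Set, Tuple

# New-style structure assumed by the updated tests: one entry per batch index,
# mapping to (min_tokens, stop_token_ids) for that request.
min_tokens: Dict[int, Tuple[int, Set[int]]] = {
    0: (8, {50256}),  # request 0: at least 8 tokens; mask this stop id until then
    1: (0, set()),    # request 1: no minimum-length constraint
}

for batch_idx, (min_toks, stop_ids) in min_tokens.items():
    print(batch_idx, min_toks, sorted(stop_ids))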

44 changes: 18 additions & 26 deletions tests/v1/sample/test_sampler.py
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
all_random=False,
rejection_sampling=False,
top_p=torch.empty(batch_size, ),
top_k=torch.empty(batch_size, ),
no_top_p=True,
no_top_k=True,
min_p=torch.empty(batch_size, ),
no_min_p=True,
top_p=None,
top_k=None,
min_p=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
spec_token_ids=[],
spec_token_ids=None,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata
@@ -104,33 +99,30 @@ def _create_default_sampling_metadata(
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
) -> Dict[int, Tuple[int, Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.

If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens.append(
min_tokens[index] = (
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens))
stop_token_ids.append(
2 * num_output_tokens),
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))

else:
min_tokens.append(np.random.randint(0, num_output_tokens))
stop_token_ids.append(set())
return (min_tokens, stop_token_ids)
min_tokens[index] = (np.random.randint(0,
num_output_tokens), set())
return min_tokens
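
A hypothetical invocation of the helper above (relying on the test module's numpy import), just to show the shape of the dict it now returns:

min_tokens = _generate_min_token_penalties_and_stop_tokens(
    num_output_tokens=10, batch_size=4, vocab_size=1024,
    batch_indices_for_min_token_penalty=[0, 2])
assert set(min_tokens) == {0, 1, 2, 3}  # one (min_tokens, stop_ids) pair per index
assert min_tokens[1][1] == set()        # non-penalized rows carry an empty stop set
assert min_tokens[0][0] > 10            # penalized rows get min_tokens > num_output_tokens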


def _create_weighted_output_token_list(
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
output_token_ids_for_batch.extend(
[token_id for _ in range(index + 1)])
output_token_ids.append(output_token_ids_for_batch)
return (output_token_ids, sorted_token_ids_in_output)
return output_token_ids, sorted_token_ids_in_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
min_tokens = _generate_min_token_penalties_and_stop_tokens(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
batch_indices_for_min_token_penalty)
sampling_metadata.min_tokens = min_tokens
sampling_metadata.stop_token_ids = stop_token_ids
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
if token_id in stop_token_ids[batch_idx]:
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
if token_id in stop_token_ids:
assert logits[batch_idx][token_id] == -float("inf")
else:
assert logits[batch_idx][token_id] != -float("inf")
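
Conceptually, the penalty this loop checks can be pictured as the following sketch (an assumption about the intent, not the actual vLLM kernel): while a row has produced fewer than its min_tokens outputs, its stop-token logits are forced to -inf so generation cannot end early.

import torch

def mask_stop_tokens(logits, output_token_ids, min_tokens):
    # logits: [batch, vocab]; output_token_ids: generated ids per row;
    # min_tokens: {row: (min_toks, stop_token_ids)} as produced by the helper above.
    for row, (min_toks, stop_ids) in min_tokens.items():
        if len(output_token_ids[row]) < min_toks:
            logits[row, list(stop_ids)] = -float("inf")
    return logits

logits = mask_stop_tokens(torch.zeros(2, 10), [[1], [1, 2, 3]], {0: (4, {7, 9})})
assert logits[0, 7] == -float("inf") and logits[0, 9] == -float("inf")
assert torch.isfinite(logits[1]).all()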
47 changes: 21 additions & 26 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pytest
@@ -41,7 +41,7 @@ def _remove_requests(
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return (req_ids_to_remove, req_indices_to_remove_list)
return req_ids_to_remove, req_indices_to_remove_list


def _construct_expected_sampling_metadata(
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
top_p = [0.0 for _ in range(num_reqs)]
min_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
for req in reqs:
if req.req_id not in req_ids_retained:
Expand All @@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
top_p[index_in_input_batch] = req.sampling_params.top_p
min_p[index_in_input_batch] = req.sampling_params.min_p
temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
rejection_sampling=False,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
no_top_p=all(x == 1.0 for x in top_p),
no_top_k=all(x == 0 for x in top_k),
min_p=torch.tensor(min_p, dtype=torch.float, device=device),
no_min_p=all(x == 0.0 for x in min_p),
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
min_p, dtype=torch.float, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
Expand All @@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
spec_token_ids=[],
spec_token_ids=None,
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch.condense(req_indices_to_remove)

# Generate the sampling metadata
sampling_metadata = input_batch.make_sampling_metadata(
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
sampling_metadata = input_batch._make_sampling_metadata()

# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(
Expand All @@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch.req_id_to_index,
device=torch.device(device))

def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
return (t1 is None
and t2 is None) or (t1 is not None and t2 is not None
and torch.allclose(t1, t2))

# Assert the actual and expected output.
assert torch.allclose(expected_sampling_metadata.temperature,
sampling_metadata.temperature)
assert torch.allclose(expected_sampling_metadata.top_p,
sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k)
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
Expand All @@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.stop_token_ids == \
sampling_metadata.stop_token_ids
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
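
The same() helper above encodes the new convention for these fields: top_p, top_k and min_p are Optional tensors, and None means the parameter is disabled for the whole batch instead of carrying separate no_top_p / no_top_k flags. A standalone sketch of that comparison logic:

from typing import Optional

import torch

def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
    # Both disabled (None) counts as equal; otherwise both must be present and close.
    return (t1 is None and t2 is None) or (
        t1 is not None and t2 is not None and torch.allclose(t1, t2))

assert same(None, None)                                # both disabled
assert not same(None, torch.tensor([0.9]))             # one disabled, one set
assert same(torch.tensor([0.9]), torch.tensor([0.9]))  # both set and equal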
33 changes: 22 additions & 11 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -5,6 +5,7 @@
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
return req_id in model_runner.requests


def _is_sampling_metadata_changed(model_runner,
sampling_metadata_before: SamplingMetadata):
return model_runner.input_batch.sampling_metadata is not (
sampling_metadata_before)
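
This helper relies on the batch_changed return value being gone: the input batch is assumed to cache its SamplingMetadata object and rebuild it only when the batch actually changes, so plain object identity (is / is not) is enough to detect a refresh. A minimal sketch of that caching pattern with a hypothetical class (not vLLM's InputBatch):

class CachingBatch:
    """Hypothetical stand-in that rebuilds its metadata only when marked dirty."""

    def __init__(self):
        self._dirty = True
        self._metadata = None

    def mark_changed(self, changed: bool):
        self._dirty = self._dirty or changed

    @property
    def sampling_metadata(self):
        if self._dirty:
            self._metadata = object()  # stands in for a freshly built SamplingMetadata
            self._dirty = False
        return self._metadata

batch = CachingBatch()
before = batch.sampling_metadata
batch.mark_changed(False)
assert batch.sampling_metadata is before      # nothing changed -> same object
batch.mark_changed(True)
assert batch.sampling_metadata is not before  # batch changed -> rebuilt object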


def test_update_states_new_request(model_runner):
req_id = "req_0"

# new req
scheduler_output = _schedule_new_request(req_id)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert not _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)

@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={},
finished_req_ids=set(),
free_encoder_input_ids=[],
)

@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is False
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert not _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)

@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
free_encoder_input_ids=[],
)

batch_changed = model_runner._update_states(scheduler_output)
assert batch_changed is True
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)

assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/utils.py
@@ -45,14 +45,14 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0]
logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits <= 0]
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
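
The switch from unsqueeze_ to unsqueeze matters because the trailing-underscore variants reshape the tensor in place; since the penalty tensors are presumably reused across sampling steps now that the metadata is cached, an in-place reshape would leave them with the wrong shape on the next call. A quick demonstration with generic tensors:

import torch

p = torch.tensor([0.5, 0.5])

q = p.unsqueeze(dim=1)   # out-of-place: p keeps its original 1-D shape
assert p.shape == (2,) and q.shape == (2, 1)

r = p.unsqueeze_(dim=1)  # in-place: the caller's tensor is reshaped as well
assert p.shape == (2, 1) and r is p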
6 changes: 4 additions & 2 deletions vllm/v1/core/scheduler.py
@@ -195,8 +195,10 @@ def schedule(self) -> "SchedulerOutput":
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids[:num_scheduled_spec_tokens])
request.spec_token_ids)
Comment on lines +199 to +201

Collaborator: What is this change for?

Member (Author): It avoids creating a new list, just trims the existing one down to num_scheduled_spec_tokens, since any later spec token ids are essentially discarded anyhow.

Collaborator: Got it! Maybe worth a comment.
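
As the author notes, del request.spec_token_ids[num_scheduled_spec_tokens:] trims the existing list in place rather than building a new one via slicing; a minimal standalone illustration with a generic list (not vLLM code):

spec_token_ids = [11, 12, 13, 14, 15]
num_scheduled_spec_tokens = 3

# Slicing allocates a new list and leaves the original untouched:
trimmed_copy = spec_token_ids[:num_scheduled_spec_tokens]
assert trimmed_copy is not spec_token_ids

# del trims the same list object in place, discarding the unscheduled tail:
del spec_token_ids[num_scheduled_spec_tokens:]
assert spec_token_ids == [11, 12, 13]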


# Encoder-related.
if encoder_inputs_to_schedule:
@@ -567,7 +569,7 @@ def update_from_output(
outputs.append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids or [],
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,