8
8
from transformers import GenerationConfig , GenerationMixin
9
9
10
10
from vllm .model_executor .layers .sampler import Sampler
11
+ from vllm .model_executor .sampling_metadata import SamplingMetadata
11
12
from vllm .model_executor .utils import set_random_seed
12
13
from vllm .sequence import SamplingParams , SequenceData , SequenceGroupMetadata
13
14
from vllm .utils import Counter
@@ -54,6 +55,7 @@ def _do_sample(
54
55
sampler : MockLogitsSampler ,
55
56
model_runner : ModelRunner ,
56
57
sampling_params : SamplingParams ,
58
+ device : str ,
57
59
):
58
60
seq_group_metadata_list = []
59
61
prompt_lens = []
@@ -68,9 +70,12 @@ def _do_sample(
68
70
))
69
71
prompt_lens .append (seq_group_metadata_list [- 1 ].seq_data [0 ].get_len ())
70
72
71
- sampling_metadata = model_runner ._prepare_sample (seq_group_metadata_list ,
72
- prompt_lens ,
73
- subquery_lens = prompt_lens )
73
+ sampling_metadata = SamplingMetadata .prepare (
74
+ seq_group_metadata_list ,
75
+ prompt_lens ,
76
+ subquery_lens = prompt_lens ,
77
+ device = device ,
78
+ pin_memory = model_runner .pin_memory )
74
79
return sampler (logits = input_tensor , sampling_metadata = sampling_metadata )
75
80
76
81
@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
85
90
86
91
sampling_params = SamplingParams (temperature = 0 )
87
92
sampler_output = _do_sample (batch_size , fake_logits , sampler , model_runner ,
88
- sampling_params )
93
+ sampling_params , device )
89
94
expected = torch .argmax (fake_logits , dim = - 1 )
90
95
for i , sequence_output in enumerate (sampler_output ):
91
96
for nth_output in sequence_output .samples :
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
111
116
n = random .randint (1 , 10 ),
112
117
)
113
118
sampler_output = _do_sample (batch_size , fake_logits , sampler , model_runner ,
114
- sampling_params )
119
+ sampling_params , device )
115
120
116
121
for i , sequence_output in enumerate (sampler_output ):
117
122
for nth_output in sequence_output .samples :
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
137
142
seed = random .randint (0 , 10000 ),
138
143
)
139
144
sampler_output = _do_sample (batch_size , fake_logits , sampler , model_runner ,
140
- sampling_params )
145
+ sampling_params , device )
141
146
142
147
for i , sequence_output in enumerate (sampler_output ):
143
148
for nth_output in sequence_output .samples :
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
160
165
seed = random .randint (0 , 10000 ),
161
166
)
162
167
first_sampler_output = _do_sample (batch_size , fake_logits , sampler ,
163
- model_runner , sampling_params )
168
+ model_runner , sampling_params , device )
164
169
165
170
second_sampler_output = _do_sample (batch_size , fake_logits , sampler ,
166
- model_runner , sampling_params )
171
+ model_runner , sampling_params , device )
167
172
168
173
assert first_sampler_output == second_sampler_output
169
174
@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
183
188
best_of = 2 ,
184
189
use_beam_search = True ,
185
190
)
186
- _do_sample (batch_size , fake_logits , sampler , model_runner , sampling_params )
191
+ _do_sample (batch_size , fake_logits , sampler , model_runner , sampling_params ,
192
+ device )
187
193
# no assertion here as I am not sure how to determine whether
188
194
# the outputs are expected - in other words, this just tests
189
195
# whether there are no exceptions in the sampler
@@ -443,10 +449,12 @@ def run_test_case(*,
443
449
"batch size" )
444
450
445
451
_ , fake_logits , sampler , model_runner = _prepare_test (batch_size )
446
- sampling_metadata = model_runner . _prepare_sample (
452
+ sampling_metadata = SamplingMetadata . prepare (
447
453
seq_group_metadata_list ,
448
454
prompt_lens = prompt_lens if prompt_lens else None ,
449
- subquery_lens = prompt_lens if prompt_lens else None )
455
+ subquery_lens = prompt_lens if prompt_lens else None ,
456
+ device = device ,
457
+ pin_memory = model_runner .pin_memory )
450
458
# the logits tensor is modified in-place by the sampler
451
459
_ = sampler (logits = fake_logits , sampling_metadata = sampling_metadata )
452
460
@@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
530
538
prompt_lens .append (seq_group_metadata_list [- 1 ].seq_data [0 ].get_len ())
531
539
532
540
def test_sampling (model_runner : ModelRunner ):
533
- sampling_metadata = model_runner ._prepare_sample (
534
- seq_group_metadata_list , prompt_lens , subquery_lens = prompt_lens )
541
+ sampling_metadata = SamplingMetadata .prepare (
542
+ seq_group_metadata_list ,
543
+ prompt_lens ,
544
+ subquery_lens = prompt_lens ,
545
+ device = device ,
546
+ pin_memory = model_runner .pin_memory )
535
547
sampler_output = sampler (logits = fake_logits ,
536
548
sampling_metadata = sampling_metadata )
537
549
@@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
627
639
))
628
640
prompt_lens .append (seq_group_metadata_list [- 1 ].seq_data [0 ].get_len ())
629
641
630
- sampling_metadata = model_runner ._prepare_sample (seq_group_metadata_list ,
631
- prompt_lens ,
632
- subquery_lens = prompt_lens )
642
+ sampling_metadata = SamplingMetadata .prepare (
643
+ seq_group_metadata_list ,
644
+ prompt_lens ,
645
+ subquery_lens = prompt_lens ,
646
+ device = device ,
647
+ pin_memory = model_runner .pin_memory )
633
648
634
649
sample_probs = None
635
650
0 commit comments