@@ -98,8 +98,9 @@ def update_requests(self, state: SampleState) -> None:
 
 
 def top_k_sampling_batch(logits, top_k=50):
-    logits_dim = logits.dim()
-    if logits_dim == 1:
+    print(f"Debug INFO Apply top_k_sampling_batch with top_k:{top_k}")
+    # logits_dim = logits.dim()
+    if logits.dim() == 1:
         logits = logits.unsqueeze(0)
     # logits should be 2D: [batch_size, vocab_size]
     batch_size, vocab_size = logits.size()
@@ -121,6 +122,7 @@ def top_k_sampling_batch(logits, top_k=50):
 
 
 def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
+    print(f"Debug INFO Apply top_p_sampling_batch with top_p:{top_p}")
     logits_dim = logits.dim()
     if logits_dim == 1:
         logits = logits.unsqueeze(0)
@@ -151,6 +153,15 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
     return next_tokens, softmax
 
 
+def temperature_sampling_batch(logits: torch.Tensor, temperature: float):
+    print(f"Debug INFO Apply temperature_sampling_batch with temperature:{temperature}")
+    assert temperature > 0, "Temperature must be positive (got {})".format(temperature)
+    scaled_logits = logits / torch.tensor([temperature], device=logits.device).unsqueeze(1)
+    softmax_probs = torch.softmax(scaled_logits, dim=-1)
+    next_tokens = torch.multinomial(softmax_probs, num_samples=1).squeeze(-1)
+    return next_tokens, softmax_probs
+
+
 def greedy_search_sampling_batch(logits):
     next_tokens = torch.argmax(logits, dim=-1)
     softmax = torch.softmax(logits, dim=-1)
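For context on the temperature sampling added in this hunk: dividing the logits by the temperature before the softmax sharpens the distribution for values below 1.0 and flattens it for values above 1.0, which is also why a temperature of exactly 1.0 is treated as disabled in the strategy selection further down. A minimal standalone sketch of the same scale-softmax-sample steps (plain PyTorch only; the logits values and temperatures are made up for illustration, this is not the patched module's code):

```python
import torch

# Toy [batch_size=2, vocab_size=3] logits, values chosen only for illustration.
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 0.5, 3.0]])

for temperature in (0.5, 1.0, 2.0):
    # Same core operation as temperature_sampling_batch: scale, softmax, sample.
    probs = torch.softmax(logits / temperature, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    print(temperature, probs.round(decimals=3).tolist(), next_tokens.tolist())
```

Lower temperatures concentrate probability on the argmax token; higher temperatures spread it across the vocabulary.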
@@ -159,18 +170,26 @@ def greedy_search_sampling_batch(logits):
 
 TopK = tuple[Literal["top_k"], int]
 TopP = tuple[Literal["top_p"], float]
+Temperature = tuple[Literal["temperature"], float]
 Greedy = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy
+Strategy = TopK | TopP | Greedy | Temperature
 
 
 def request_strategy(request: LlmRequest) -> Strategy:
-    if request.sampling_config.top_p is not None and len(
-            request.sampling_config.top_p) > 0:
-        return ("top_p", request.sampling_config.top_p[0])
-    elif request.sampling_config.top_k is not None and len(
-            request.sampling_config.top_k) > 0:
-        return ("top_k", request.sampling_config.top_k[0])
+    top_p = request.sampling_config.top_p[0] if request.sampling_config.top_p is not None and len(
+        request.sampling_config.top_p) > 0 else None
+    top_k = request.sampling_config.top_k[0] if request.sampling_config.top_k is not None and len(
+        request.sampling_config.top_k) > 0 else None
+    temperature = request.sampling_config.temperature[0] if request.sampling_config.temperature is not None and len(
+        request.sampling_config.temperature) > 0 and request.sampling_config.temperature[0] > 0 else None
+
+    if top_p is not None and top_p != 1.0:
+        return ("top_p", top_p)
+    elif top_k is not None and top_k != 0:
+        return ("top_k", top_k)
+    elif temperature is not None and temperature != 1.0:
+        return ("temperature", temperature)
     else:
         return ("greedy", None)
 
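The rewritten request_strategy establishes an explicit precedence: top_p first, then top_k, then temperature, with 1.0 (for top_p and temperature) and 0 (for top_k) treated as disabled; anything else falls back to greedy. A rough standalone sketch of that precedence, using types.SimpleNamespace as a stand-in for the real sampling_config (field names follow the diff; the configs and the helper name pick_strategy are made up for illustration):

```python
from types import SimpleNamespace

def pick_strategy(cfg):
    # Mirrors the precedence in request_strategy: top_p > top_k > temperature > greedy.
    top_p = cfg.top_p[0] if cfg.top_p else None
    top_k = cfg.top_k[0] if cfg.top_k else None
    temperature = cfg.temperature[0] if cfg.temperature and cfg.temperature[0] > 0 else None
    if top_p is not None and top_p != 1.0:
        return ("top_p", top_p)
    if top_k is not None and top_k != 0:
        return ("top_k", top_k)
    if temperature is not None and temperature != 1.0:
        return ("temperature", temperature)
    return ("greedy", None)

print(pick_strategy(SimpleNamespace(top_p=[0.9], top_k=[40], temperature=[0.7])))  # ('top_p', 0.9)
print(pick_strategy(SimpleNamespace(top_p=[1.0], top_k=[0], temperature=[0.7])))   # ('temperature', 0.7)
print(pick_strategy(SimpleNamespace(top_p=None, top_k=None, temperature=[1.0])))   # ('greedy', None)
```

Note that a request setting both top_p and a temperature only reaches the temperature path when top_p is 1.0 or unset; the strategies are selected exclusively rather than combined.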
@@ -185,6 +204,8 @@ def sample(strategy: Strategy, logits: torch.Tensor):
             return top_k_sampling_batch(logits, top_k)
         case ("top_p", top_p):
             return top_p_sampling_batch(logits, top_p)
+        case ("temperature", temperature):
+            return temperature_sampling_batch(logits, temperature)
         case ("greedy", None):
             return greedy_search_sampling_batch(logits)
 
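With the new case arm, the dispatcher accepts a ("temperature", value) tuple alongside the existing strategies. A hedged usage sketch, assuming sample and GREEDY are imported from the patched sampler module (the module path is not shown in this diff) and that the tensor shape is purely illustrative:

```python
import torch

logits = torch.randn(4, 32000)  # [batch_size, vocab_size], illustrative shape

# Dispatches to temperature_sampling_batch via the new match arm.
next_tokens, probs = sample(("temperature", 0.8), logits)

# Dispatches to greedy_search_sampling_batch.
greedy_tokens, _ = sample(GREEDY, logits)
```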
@@ -404,29 +425,20 @@ def _process_requests(self,
         num_steps = [1 + len(req.py_draft_tokens) for req in requests]
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
-        fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
 
         seq_slots = torch.as_tensor([r.seq_slot for r in requests])
         seq_slots = seq_slots.to(device="cuda", non_blocking=True)
 
-        if fast_path:
-            logits = raw_logits[:len(requests)]
-            next_tokens = torch.argmax(logits, dim=-1)
-            self.append_eagle3(next_tokens, model_outputs)
-            int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
-            next_tokens = int_next_tokens.view(1, -1, beam_width)
-            new_tokens[:1].index_copy_(1, seq_slots, next_tokens)
-            return
-
         strategies = sampling_strategies(requests)
         batched_next_tokens, batched_softmax = None, None
         batched_strategy: Strategy | None = GREEDY
         if self.enable_mixed_sampler:
             assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling"
-            if len(set(strategies)) == 1:
-                batched_strategy = strategies[0]
-            else:
-                batched_strategy = None
+
+        if len(set(strategies)) == 1:
+            batched_strategy = strategies[0]
+        else:
+            batched_strategy = None
 
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
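Removing the fast path means every call now goes through the strategy logic: when all requests resolve to the same strategy, one batched sample call covers all rows of raw_logits; otherwise each request's slice is sampled on its own (next hunk). A small standalone sketch of that decision, with made-up strategies and a plain tensor standing in for raw_logits:

```python
import torch

# Stand-ins for sampling_strategies(requests) and raw_logits; values are illustrative.
strategies = [("top_p", 0.9), ("greedy", None), ("top_p", 0.9)]
raw_logits = torch.randn(3, 32000)

# Same decision as above: batch only when every request uses one strategy.
batched_strategy = strategies[0] if len(set(strategies)) == 1 else None

if batched_strategy is not None:
    slices = [raw_logits]                                               # one batched call
else:
    slices = [raw_logits[i:i + 1] for i in range(len(strategies))]     # per-request calls

print(batched_strategy, [tuple(s.shape) for s in slices])
```

One consequence of moving the `if len(set(strategies)) == 1` block out of the enable_mixed_sampler branch is that the uniform-strategy check now runs regardless of whether mixed sampling is enabled.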
@@ -440,6 +452,7 @@ def _process_requests(self,
             logits = raw_logits[input_slice]
             if batched_next_tokens is None:
                 next_tokens, softmax = sample(strategy, logits)
+                self.append_eagle3(next_tokens, model_outputs)
             else:
                 next_tokens = batched_next_tokens[input_slice]
                 softmax = batched_softmax[input_slice]