
Commit d5a9f64

xq25478 <huwen.hu@antgroup.com> authored and committed
imp(torchsampler): support sampling params temperature/top_p/top_k
Signed-off-by: xq25478 <[email protected]>
1 parent cfcb97a commit d5a9f64

File tree

1 file changed (+36, -23 lines)


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 36 additions & 23 deletions
@@ -98,8 +98,9 @@ def update_requests(self, state: SampleState) -> None:
 
 
 def top_k_sampling_batch(logits, top_k=50):
-    logits_dim = logits.dim()
-    if logits_dim == 1:
+    print(f"Debug INFO Apply top_k_sampling_batch with top_k:{top_k}")
+    # logits_dim = logits.dim()
+    if logits.dim() == 1:
         logits = logits.unsqueeze(0)
     # logits should be 2D :[batch_size, vocab_size]
     batch_size, vocab_size = logits.size()
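
For context, top-k sampling keeps only the k highest-scoring logits and renormalizes before drawing a token. A minimal standalone sketch of the idea, not the file's implementation (names and the toy tensor are illustrative):

import torch


def top_k_sketch(logits: torch.Tensor, top_k: int = 50):
    # logits: [batch_size, vocab_size]
    k = min(top_k, logits.size(-1))
    values, _ = torch.topk(logits, k=k, dim=-1)
    # Everything below the k-th largest logit is masked to -inf.
    threshold = values[..., -1, None]
    masked = logits.masked_fill(logits < threshold, float("-inf"))
    probs = torch.softmax(masked, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return next_tokens, probs


# Toy usage: batch of 2 over a 6-token vocabulary.
tokens, probs = top_k_sketch(torch.randn(2, 6), top_k=3)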
@@ -121,6 +122,7 @@ def top_k_sampling_batch(logits, top_k=50):
 
 
 def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
+    print(f"Debug INFO Apply top_p_sampling_batch with top_p:{top_p}")
     logits_dim = logits.dim()
     if logits_dim == 1:
         logits = logits.unsqueeze(0)
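
top_p_sampling_batch is nucleus sampling: tokens are sorted by probability and only the smallest prefix whose cumulative mass reaches top_p is kept. A standalone sketch of that filtering step (illustrative, not the file's code):

import torch


def top_p_sketch(logits: torch.Tensor, top_p: float = 0.9):
    # logits: [batch_size, vocab_size]
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = sorted_probs.cumsum(dim=-1)
    # Drop a token once the mass *before* it already exceeds top_p,
    # so the highest-probability token always survives.
    remove = (cum_probs - sorted_probs) > top_p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    filtered = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    probs = torch.softmax(filtered, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return next_tokens, probs


# Toy usage.
tokens, probs = top_p_sketch(torch.randn(2, 6), top_p=0.9)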
@@ -151,6 +153,15 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
     return next_tokens, softmax
 
 
+def temperature_sampling_batch(logits: torch.Tensor, temperature: float):
+    print(f"Debug INFO Apply temperature_sampling_batch with temperature:{temperature}")
+    assert temperature > 0, "Temperature must be positive (got {})".format(temperature)
+    scaled_logits = logits / torch.tensor([temperature], device=logits.device).unsqueeze(1)
+    softmax_probs = torch.softmax(scaled_logits, dim=-1)
+    next_tokens = torch.multinomial(softmax_probs, num_samples=1).squeeze(-1)
+    return next_tokens, softmax_probs
+
+
 def greedy_search_sampling_batch(logits):
     next_tokens = torch.argmax(logits, dim=-1)
     softmax = torch.softmax(logits, dim=-1)
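
The new temperature_sampling_batch divides the logits by the temperature before the softmax, so T < 1 sharpens the distribution and T > 1 flattens it. A minimal equivalent sketch (a plain scalar division broadcasts the same way as the [1, 1] tensor built in the patch):

import torch


def temperature_sketch(logits: torch.Tensor, temperature: float):
    # logits: [batch_size, vocab_size]; temperature must be > 0.
    assert temperature > 0, f"Temperature must be positive (got {temperature})"
    probs = torch.softmax(logits / temperature, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return next_tokens, probs


# Toy usage: a low temperature concentrates mass on the argmax token.
tokens, probs = temperature_sketch(torch.randn(2, 6), temperature=0.7)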
@@ -159,18 +170,26 @@ def greedy_search_sampling_batch(logits):
 
 TopK = tuple[Literal["top_k"], int]
 TopP = tuple[Literal["top_p"], float]
+Temperature = tuple[Literal["temperature"], float]
 Greedy = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy
+Strategy = TopK | TopP | Greedy | Temperature
 
 
 def request_strategy(request: LlmRequest) -> Strategy:
-    if request.sampling_config.top_p is not None and len(
-            request.sampling_config.top_p) > 0:
-        return ("top_p", request.sampling_config.top_p[0])
-    elif request.sampling_config.top_k is not None and len(
-            request.sampling_config.top_k) > 0:
-        return ("top_k", request.sampling_config.top_k[0])
+    top_p = request.sampling_config.top_p[0] if request.sampling_config.top_p is not None and len(
+        request.sampling_config.top_p) > 0 else None
+    top_k = request.sampling_config.top_k[0] if request.sampling_config.top_k is not None and len(
+        request.sampling_config.top_k) > 0 else None
+    temperature = request.sampling_config.temperature[0] if request.sampling_config.temperature is not None and len(
+        request.sampling_config.temperature) > 0 and request.sampling_config.temperature[0] > 0 else None
+
+    if top_p is not None and top_p != 1.0:
+        return ("top_p", top_p)
+    elif top_k is not None and top_k != 0:
+        return ("top_k", top_k)
+    elif temperature is not None and temperature != 1.0:
+        return ("temperature", temperature)
     else:
         return ("greedy", None)
 
@@ -185,6 +204,8 @@ def sample(strategy: Strategy, logits: torch.Tensor):
             return top_k_sampling_batch(logits, top_k)
         case ("top_p", top_p):
             return top_p_sampling_batch(logits, top_p)
+        case ("temperature", temperature):
+            return temperature_sampling_batch(logits, temperature)
         case ("greedy", None):
             return greedy_search_sampling_batch(logits)
 
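With the new match arm, callers select temperature sampling by passing a ("temperature", t) tuple to sample(). A hedged usage sketch, assuming the module-level sample() is importable from the file shown in this diff:

import torch

from tensorrt_llm._torch.pyexecutor.sampler import sample

logits = torch.randn(4, 32000)  # [batch_size, vocab_size]
for strategy in [("greedy", None), ("top_k", 40), ("top_p", 0.9), ("temperature", 0.7)]:
    next_tokens, softmax = sample(strategy, logits)  # every arm returns (tokens, probs)
    print(strategy[0], tuple(next_tokens.shape))
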
@@ -404,29 +425,20 @@ def _process_requests(self,
         num_steps = [1 + len(req.py_draft_tokens) for req in requests]
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
-        fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
 
         seq_slots = torch.as_tensor([r.seq_slot for r in requests])
         seq_slots = seq_slots.to(device="cuda", non_blocking=True)
 
-        if fast_path:
-            logits = raw_logits[:len(requests)]
-            next_tokens = torch.argmax(logits, dim=-1)
-            self.append_eagle3(next_tokens, model_outputs)
-            int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
-            next_tokens = int_next_tokens.view(1, -1, beam_width)
-            new_tokens[:1].index_copy_(1, seq_slots, next_tokens)
-            return
-
         strategies = sampling_strategies(requests)
         batched_next_tokens, batched_softmax = None, None
         batched_strategy: Strategy | None = GREEDY
         if self.enable_mixed_sampler:
             assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling"
-            if len(set(strategies)) == 1:
-                batched_strategy = strategies[0]
-            else:
-                batched_strategy = None
+
+        if len(set(strategies)) == 1:
+            batched_strategy = strategies[0]
+        else:
+            batched_strategy = None
 
         if batched_strategy is not None:
             logits = raw_logits[:sum_steps]
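
In this hunk the batching decision no longer sits inside the enable_mixed_sampler branch: if every request in the batch resolves to the same strategy, the whole batch is sampled in one call; mixed strategies fall back to per-request sampling. A toy sketch of that check:

def choose_batched_strategy(strategies):
    # Mirrors the hunk above: a single shared strategy can be applied to
    # the whole batch; mixed strategies return None (sample per request).
    if len(set(strategies)) == 1:
        return strategies[0]
    return None


# Toy usage.
assert choose_batched_strategy([("top_p", 0.9)] * 3) == ("top_p", 0.9)
assert choose_batched_strategy([("top_p", 0.9), ("greedy", None)]) is None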
@@ -440,6 +452,7 @@ def _process_requests(self,
                 logits = raw_logits[input_slice]
                 if batched_next_tokens is None:
                     next_tokens, softmax = sample(strategy, logits)
+                    self.append_eagle3(next_tokens, model_outputs)
                 else:
                     next_tokens = batched_next_tokens[input_slice]
                     softmax = batched_softmax[input_slice]
