@@ -9,6 +9,7 @@
 from typing import Any, List, Literal, Optional, cast
 
 import torch
+import torch.nn.functional as F
 
 from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
     MakeDecodingBatchInputOutput
@@ -852,15 +853,19 @@ def _handle_stop_criteria(self, request: LlmRequest,
 
     def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
                         beam: int, count: int):
-        current_slice = slice(0, count), request.py_seq_slot, beam
         if request.py_return_log_probs:
-            assert state.host.log_probs is not None
-            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
-            current_tokens = state.host.new_tokens[current_slice]
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
             token_log_probs = [{
-                int(token): Logprob(logprob=logprob, rank=1)
-            } for token, logprob in zip(current_tokens, log_probs.tolist())]
+                int(token):
+                Logprob(logprob=logprob, rank=rank + 1)
+                for rank, (token, logprob) in enumerate(
+                    zip(topk_token, topk_logprob.tolist()))
+            }
+                               for topk_token, topk_logprob in zip(
+                                   topk_log_probs_indices, topk_log_probs_vals)]
+
             assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
             request.py_result.append_log_probs([token_log_probs])
 
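For illustration, here is a self-contained sketch of the structure the new comprehension builds: one dict per generated position, mapping each of the top-k token ids to its logprob and 1-based rank. The `Logprob` namedtuple and the toy tensors below are stand-ins for the sampler's `Logprob` type and for `request.py_topk_logprobs_vals` / `request.py_topk_logprobs_indices`.

```python
import torch
from collections import namedtuple

# Stand-in for the sampler's Logprob type (assumed fields: logprob, rank).
Logprob = namedtuple("Logprob", ["logprob", "rank"])

# Toy inputs: 2 generated positions, top-3 logprobs each.
topk_log_probs_vals = torch.tensor([[-0.1, -2.3, -3.0], [-0.5, -1.2, -4.1]])
topk_log_probs_indices = torch.tensor([[11, 42, 7], [99, 3, 15]])

token_log_probs = [{
    int(token): Logprob(logprob=logprob, rank=rank + 1)
    for rank, (token, logprob) in enumerate(
        zip(topk_token, topk_logprob.tolist()))
} for topk_token, topk_logprob in zip(topk_log_probs_indices,
                                      topk_log_probs_vals)]

# token_log_probs[0] maps token 11 -> Logprob(-0.1, rank=1),
# token 42 -> Logprob(-2.3, rank=2), token 7 -> Logprob(-3.0, rank=3).
```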
@@ -970,13 +975,8 @@ def log_probs_host(
             self,
             scheduled_requests: ScheduledRequests) -> Optional[torch.Tensor]:
         """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103"""
-        if any(req.py_return_log_probs
-               for req in scheduled_requests.all_requests()):
-            return torch.empty(
-                (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
-                device="cpu",
-                pin_memory=True)
-        return None
+        return any(req.py_return_log_probs
+                   for req in scheduled_requests.all_requests())
 
     @override
     @torch.inference_mode()
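`log_probs_host` no longer pre-allocates a pinned `(max_num_sequences, MAX_BEAM_WIDTH, max_tokens)` host tensor; it now only reports whether any scheduled request asked for logprobs, and the top-k values are attached to individual requests later in `_sample_batched_by_strategy`. A minimal sketch of the new predicate, using stand-in request/scheduler classes rather than the real ones:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FakeRequest:
    py_return_log_probs: bool

@dataclass
class FakeScheduledRequests:
    requests: List[FakeRequest]

    def all_requests(self) -> List[FakeRequest]:
        return self.requests

def log_probs_host(scheduled_requests: FakeScheduledRequests) -> bool:
    # No host buffer is allocated here anymore; the helper is just a flag.
    return any(req.py_return_log_probs
               for req in scheduled_requests.all_requests())

assert log_probs_host(
    FakeScheduledRequests([FakeRequest(False), FakeRequest(True)]))
```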
@@ -1001,8 +1001,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         sampler_event.record()
         return SampleState(scheduled_requests=scheduled_requests,
                            device=SampleStateTensors(new_tokens=new_tokens),
-                           host=SampleStateTensors(new_tokens=new_tokens_host,
-                                                   log_probs=log_probs_host),
+                           host=SampleStateTensors(new_tokens=new_tokens_host),
                            sampler_event=sampler_event)
 
     @staticmethod
@@ -1111,12 +1110,24 @@ def _sample_batched_by_strategy(
         model_outputs: dict[str, torch.Tensor],
         *,
         cuda_device: torch.device,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
         req_num_steps: torch.Tensor,
         req_offsets: torch.Tensor,
         steps_dim_size: int,
         token_dtype: torch.dtype,
     ) -> _BatchedSamplingResult:
+        if log_probs_host:
+            assert logits_cuda.dim() == 2, "logits should be 2D"
+            logprobs = F.log_softmax(logits_cuda.to("cuda",
+                                                    dtype=torch.float32),
+                                     dim=-1)
+            topk_vals, topk_indices = torch.topk(logprobs,
+                                                 k=max(req.py_num_logprobs
+                                                       for req in requests),
+                                                 dim=-1)
+            topk_vals = topk_vals.to(device="cpu", non_blocking=True)
+            topk_indices = topk_indices.to(device="cpu", non_blocking=True)
+
         requests_by_strategy = _group_requests_by_sampling_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
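In isolation, the added top-k path boils down to the following pattern (toy shapes; an explicit `torch.cuda.synchronize()` stands in for the sampler event that guards the non-blocking device-to-host copies in the real code):

```python
import torch
import torch.nn.functional as F

def topk_logprobs(logits_cuda: torch.Tensor, k: int):
    # Log-probabilities in float32 for numerical stability, then keep the
    # k largest entries per row.
    logprobs = F.log_softmax(logits_cuda.to(dtype=torch.float32), dim=-1)
    topk_vals, topk_indices = torch.topk(logprobs, k=k, dim=-1)
    # Start asynchronous D2H copies; callers must synchronize before reading
    # the CPU tensors.
    return (topk_vals.to("cpu", non_blocking=True),
            topk_indices.to("cpu", non_blocking=True))

if torch.cuda.is_available():
    logits = torch.randn(4, 32000, device="cuda")  # [num_tokens, vocab_size]
    vals, idx = topk_logprobs(logits, k=5)
    torch.cuda.synchronize()
    print(vals.shape, idx.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```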
@@ -1160,12 +1171,18 @@ def _sample_batched_by_strategy(
             #   softmax_grp_indices: Indices of 'speculation_group_indices' entries requesting probs
             #   speculation_softmax_indices: Indices of 'softmax_grp_indices' entries corresponding
             #     to requests with draft logits.
-            if log_probs_host is not None:
+            if log_probs_host:
                 softmax_req_indices = group_req_indices
                 softmax_grp_indices = torch.arange(len(group_req_indices),
                                                    dtype=torch.int32)
                 speculation_softmax_indices = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
+                for req_id in group_req_indices:
+                    req = requests[req_id]
+                    req.py_topk_logprobs_vals = topk_vals[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs]
+                    req.py_topk_logprobs_indices = topk_indices[
+                        logits_cuda_indexer[req_id], :req.py_num_logprobs]
             else:
                 speculation_group_indices_tensor = torch.tensor(
                     speculation_group_indices, dtype=torch.int32)
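A simplified sketch of the per-request slicing performed in the loop above. The dict `row_of` is a hypothetical stand-in for `logits_cuda_indexer` (here, one packed row per request, i.e. a single decode step), and each request keeps only its first `py_num_logprobs` columns of the batched top-k tensors:

```python
import torch

class FakeRequest:
    def __init__(self, num_logprobs: int):
        self.py_num_logprobs = num_logprobs
        self.py_topk_logprobs_vals = None
        self.py_topk_logprobs_indices = None

requests = [FakeRequest(2), FakeRequest(5)]
k_max = max(r.py_num_logprobs for r in requests)
topk_vals = torch.randn(len(requests), k_max)      # one packed row per request
topk_indices = torch.randint(0, 32000, (len(requests), k_max))
row_of = {i: i for i in range(len(requests))}      # stand-in for logits_cuda_indexer

for req_id, req in enumerate(requests):
    req.py_topk_logprobs_vals = topk_vals[row_of[req_id], :req.py_num_logprobs]
    req.py_topk_logprobs_indices = topk_indices[row_of[req_id], :req.py_num_logprobs]

assert requests[0].py_topk_logprobs_vals.shape == (2,)
assert requests[1].py_topk_logprobs_indices.shape == (5,)
```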
@@ -1257,7 +1274,7 @@ def _unbatch_sampling_results(
         new_tokens_cuda: torch.Tensor,
         req_num_steps: torch.Tensor,
         seq_slots: torch.Tensor,
-        log_probs_host: torch.Tensor | None = None,
+        log_probs_host: bool = False,
     ) -> torch.Tensor:
         beam = self.BEAM
         assert beam == 0, "beam_width != 1 not supported"
@@ -1274,17 +1291,6 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         # Assert destination tensor dimensions are canonically ordered ("row"-major); this
         # matters for element ordering in the .view(...).scatter_(...) calls below.
         assert _dims_canonically_ordered(new_tokens_cuda)
-        assert log_probs_host is None or _dims_canonically_ordered(
-            log_probs_host)
-
-        # new_tokens_cuda indexed by
-        #   slice(0, steps), slot, beam
-        # log_probs_host indexed by
-        #   slot, beam, slice(0, steps)
-        # batch_... tensors indexed by slice(batch_req_index, batch_req_index + steps)
-        #
-        if log_probs_host is not None:
-            assert new_tokens_cuda.size(0) == log_probs_host.size(-2)
 
         # Construct index mapping from slice indices of computed tensors
         # (packed request_idx and step dimensions) to linearized indices
@@ -1306,39 +1312,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
                 0, batch_dest_indices_1d_cuda,
                 batch_next_tokens_cuda_int)
         new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True)
-        # NB: In order to avoid a scatter_ on the host and the necessary D2H copy + synchronization,
-        # the 'step' and 'seq_slot' dimensions are unpacked on GPU and later asynchronously
-        # copied into the destination buffer. Note that this overwrites all 'step' and token slots for the
-        # requests in 'requests' (passed to _process_requests). In fact, the current implementation
-        # even overwrites the destination tensors completely (including slices corresponding to request
-        # slots not present in 'requests', cf. 'FIXME' below).
-        if log_probs_host is not None:
-            # FIXME: If log_probs_host were indexed by request indices, rather than request slots, this
-            # tensor could be packed densely along the request axis.
-            log_probs_cuda = torch.empty_like(
-                log_probs_host, device=batch_dest_indices_1d_cuda.device)
-            # FIXME: Needs a separate indexer because tensor layout differs from new_tokens_cuda
-            batch_dest_probs_cuda_indexer = _UnpackedStepIndexer(
-                seq_slots=seq_slots[batch_req_indices],
-                num_steps=req_num_steps[batch_req_indices],
-                steps_dim_size=new_tokens_cuda.size(0),
-                slots_dim_size=new_tokens_cuda.size(1),
-                dim_order=_UnpackedStepIndexer.DimOrder.SLOT_MAJOR,
-                index_dtype=torch.int64,  # enforced by Tensor.scatter_
-            )
-            batch_dest_probs_indices_cuda = batch_dest_probs_cuda_indexer[:].to(
-                batch_softmax_cuda.device, non_blocking=True)
-            # NB: torch.arange is needed to enable "advanced indexing",
-            # cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing
-            batch_token_probs = batch_softmax_cuda[
-                torch.arange(batch_softmax_cuda.size(0),
-                             device=batch_softmax_cuda.device,
-                             dtype=torch.int32), batch_next_tokens_cuda_int]
-            log_probs_cuda[:, beam,
-                           ...].view(-1, *log_probs_cuda.shape[3:]).scatter_(
-                               0, batch_dest_probs_indices_cuda,
-                               torch.log(batch_token_probs))
-            log_probs_host.copy_(log_probs_cuda, non_blocking=True)
+
         # For requests with LlmRequest.py_draft_logits, return py_target_probs
         for request, batch_softmax_index_cuda in py_draft_logits_indices:
             request.py_target_probs = batch_softmax_cuda[
@@ -1481,7 +1455,6 @@ def _process_requests(
 
         logits_cuda = self._apply_min_length_penalty(logits_cuda, requests,
                                                      req_num_steps_list)
-
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
             logits_cuda,