
Commit 654bac2

[TRTLLM-5271] feat: best_of/n for pytorch workflow
Signed-off-by: Olya Kozlova <[email protected]>
1 parent: e42f5a9

File tree: 6 files changed, +175 -42 lines


cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 6 additions & 0 deletions
@@ -467,6 +467,9 @@ class GenericLlmRequest
         initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
     }
 
+    GenericLlmRequest(GenericLlmRequest&& request) = default;
+    GenericLlmRequest(GenericLlmRequest const& request) = default;
+
     void setExcludeInputFromOutput(bool exclude)
     {
         mExcludeInputFromOutput = exclude;
@@ -2315,6 +2318,9 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
         mKvCacheRetentionConfig = request.getKvCacheRetentionConfig();
     }
 
+    LlmRequest(LlmRequest&& request) = default;
+    LlmRequest(LlmRequest const& request) = default;
+
     /// @brief Create a Response from the current state of the request
     /// @details Note that there is some dependency on the order of operations in this method. Modify with care!
     /// @return An optional Response

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 6 additions & 1 deletion
@@ -194,6 +194,8 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
         .def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
         .def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
+        .def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
+        .def_property_readonly("is_child", &GenLlmReq::isChild)
         .def_property_readonly("multimodal_hashes",
             [](GenLlmReq& self)
             {
@@ -256,7 +258,7 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
 
     py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
-        .def(py::init(
+        .def(py::init<>(
             [](tb::LlmRequest::RequestIdType request_id, tb::LlmRequest::SizeType32 max_new_tokens,
                 std::vector<tb::LlmRequest::TokenIdType> input_tokens, runtime::SamplingConfig sampling_config,
                 bool is_streaming, std::optional<tb::LlmRequest::SizeType32> end_id,
@@ -359,11 +361,14 @@ void initBindings(pybind11::module_& m)
             py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
             py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
             py::arg("context_phase_params") = std::nullopt)
+        .def(py::init<tb::LlmRequest const&>())
+        //.def(py::init<tb::LlmRequest&&>())
         .def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
             py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_endocer_input_len") = std::nullopt,
             py::arg("enable_kv_cache_reuse") = false)
         .def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
             py::arg("mpi_world_rank") = 0)
+        .def("create_child_request", &tb::LlmRequest::createChildRequest, py::arg("child_id"))
         .def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
             py::arg("mpi_world_rank") = 0)
         .def("create_serialized_result",
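Note: together with the defaulted copy and move constructors added in llmRequest.h above, these bindings surface the parent/child request machinery to Python. A minimal sketch of the Python-side surface this enables (the module path is the one referenced by the PyTorch LlmRequest wrapper below; `request` is assumed to already exist, and the child id here is illustrative since real ids are reserved by the executor):

from tensorrt_llm.bindings.internal.batch_manager import LlmRequest

# `request` is assumed to be an already constructed LlmRequest instance.
copied = LlmRequest(request)  # copy binding added above: py::init<tb::LlmRequest const&>()

child = request.create_child_request(child_id=request.request_id + 1)  # illustrative id
assert child.is_child                                  # new read-only property
assert child.parent_request_id == request.request_id   # new read-only property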

examples/llm-api/quickstart_advanced.py

Lines changed: 12 additions & 8 deletions
@@ -104,6 +104,8 @@ def add_llm_args(parser):
     parser.add_argument("--top_k", type=int, default=None)
     parser.add_argument("--top_p", type=float, default=None)
     parser.add_argument('--load_format', type=str, default='auto')
+    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--best_of', type=int, default=None)
     parser.add_argument('--max_beam_width', type=int, default=1)
 
     # Speculative decoding
@@ -229,10 +231,12 @@ def setup_llm(args):
         temperature=args.temperature,
         top_k=args.top_k,
         top_p=args.top_p,
+        best_of=args.max_beam_width
+        if args.max_beam_width > 1 else args.best_of,
         return_context_logits=args.return_context_logits,
         return_generation_logits=args.return_generation_logits,
         logprobs=args.logprobs,
-        n=args.max_beam_width,
+        n=args.n,
         use_beam_search=args.max_beam_width > 1)
     return llm, sampling_params
 
@@ -246,23 +250,23 @@ def main():
 
     for i, output in enumerate(outputs):
         prompt = output.prompt
-        for beam_idx, beam in enumerate(output.outputs):
-            generated_text = beam.text
+        for sequence_idx, sequence in enumerate(output.outputs):
+            generated_text = sequence.text
             # Skip printing the beam_idx if no beam search was used
-            beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
+            sequence_id_text = f"[{sequence_idx}]" if args.max_beam_width > 1 or args.n > 1 else ""
             print(
-                f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
+                f"[{i}]{sequence_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
             )
             if args.return_context_logits:
                 print(
-                    f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
+                    f"[{i}]{sequence_id_text} Context logits: {output.context_logits}"
                 )
             if args.return_generation_logits:
                 print(
-                    f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
+                    f"[{i}]{sequence_id_text} Generation logits: {sequence.generation_logits}"
                 )
             if args.logprobs:
-                print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
+                print(f"[{i}]{sequence_id_text} Logprobs: {sequence.logprobs}")
 
 
 if __name__ == '__main__':
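The new flags feed straight into the LLM API's SamplingParams. A hedged usage sketch of the same best_of/n behaviour outside the script (the checkpoint and prompt are placeholders, not part of this commit):

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder checkpoint

# Sample four candidate sequences per prompt and return two of them.
# With use_beam_search left at its default (False), the extra sequences are
# realized as child requests in the PyTorch executor rather than as beams.
sampling_params = SamplingParams(max_tokens=32,
                                 temperature=0.8,
                                 top_p=0.95,
                                 n=2,
                                 best_of=4)

for output in llm.generate(["The capital of France is"], sampling_params):
    for sequence_idx, sequence in enumerate(output.outputs):
        print(f"[{sequence_idx}] {sequence.text!r}")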

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 45 additions & 11 deletions
@@ -277,22 +277,28 @@ def __init__(
             exclude_last_generation_logits: bool = False,
             return_perf_metrics: bool = False,
             stop_words_list: list[list[int]] | None = None,
+            llm_request: Optional[
+                tensorrt_llm.bindings.internal.batch_manager.LlmRequest] = None,
             is_draft: bool = False,
             **kwargs):
+
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
                                                     None)
         # Multimodal data
         self.py_multimodal_data = kwargs.pop("py_multimodal_data", None)
-        super().__init__(
-            *args,
-            client_id=client_id,
-            return_log_probs=return_log_probs,
-            return_context_logits=False,
-            return_generation_logits=False,
-            return_perf_metrics=return_perf_metrics,
-            stop_words_list=torch.tensor(stop_words_list, dtype=torch.int32)
-            if stop_words_list else None,
-            **kwargs)
+        if llm_request is not None:
+            super().__init__(llm_request)
+        else:
+            super().__init__(
+                *args,
+                client_id=client_id,
+                return_log_probs=return_log_probs,
+                return_context_logits=False,
+                return_generation_logits=False,
+                return_perf_metrics=return_perf_metrics,
+                stop_words_list=torch.tensor(stop_words_list, dtype=torch.int32)
+                if stop_words_list else None,
+                **kwargs)
         self.py_client_id = client_id
         self.py_request_id = self.request_id
         self.py_llm_request_type = self.llm_request_type
@@ -326,6 +332,7 @@ def __init__(
                                   return_log_probs, return_context_logits,
                                   return_generation_logits,
                                   exclude_last_generation_logits)
+        self.children = []
 
     def is_generation_only_request(self):
         return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
@@ -337,7 +344,8 @@ def create_response(
         result, is_final = super().create_serialized_result(
             use_fast_logits, mpi_world_rank)
         return LlmResponse(
-            request_id=self.py_request_id,
+            request_id=self.py_request_id
+            if self.is_child else self.parent_request_id,
             result=LlmResult(result, self.py_result, is_final),
             client_id=self.py_client_id) if len(result) > 0 else None
 
@@ -350,6 +358,26 @@ def finish_by(self, reason: FinishReason, beam: int) -> None:
         self.state = LlmRequestState.GENERATION_COMPLETE
         self.set_finished_reason(reason, beam)
 
+    def create_child_request(self, child_id):
+        child = super().create_child_request(child_id)
+        py_request = LlmRequest(llm_request=child)
+        py_request.__dict__.update(**self.__dict__)
+
+        py_request.py_result = PyResult(
+            self.py_prompt_len, self.py_max_new_tokens,
+            self.py_return_logits_device_memory, self.streaming,
+            self.py_return_log_probs, self.py_return_context_logits,
+            self.py_return_generation_logits)
+        py_request.py_request_id = child.request_id
+        py_request.children = []
+
+        assert py_request.is_child
+        assert py_request.request_id == child.request_id
+        assert py_request.parent_request_id == self.request_id
+        assert py_request.sampling_config.random_seed != self.sampling_config.random_seed
+
+        return py_request
+
 
 def convert_wordlist(word_list) -> List[List[int]]:
     """Converts a wordlist from format:
@@ -391,6 +419,7 @@ def convert_wordlist(word_list) -> List[List[int]]:
 def executor_request_to_llm_request(
         req_id: int,
         executor_request: ExecutorRequest,
+        child_req_ids: List[int],
         exclude_last_generation_logits: bool,
         input_token_ids: Optional[List] = None) -> LlmRequest:
     executor_sampling_config = executor_request.sampling_config
@@ -475,4 +504,9 @@ def executor_request_to_llm_request(
         context_phase_params=executor_request.context_phase_params,
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None))
+    if child_req_ids:
+        for child_id in child_req_ids:
+            llm_request.children.append(
+                llm_request.create_child_request(child_id))
+
     return llm_request
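With the new child_req_ids parameter, one executor request fans out into a parent LlmRequest plus one pre-built child per reserved id. A rough sketch of the resulting structure (the ExecutorRequest construction is elided; the ids and sequence count are illustrative):

# Sketch: assumes `executor_request` was built with num_return_sequences == 3
# and beam_width == 1, and that ids 43 and 44 were reserved for the children.
parent = executor_request_to_llm_request(
    req_id=42,
    executor_request=executor_request,
    child_req_ids=[43, 44],
    exclude_last_generation_logits=False)

assert [child.py_request_id for child in parent.children] == [43, 44]
assert all(child.is_child and child.parent_request_id == 42
           for child in parent.children)
# Each child carries its own PyResult and a random seed different from the
# parent's, so sampling diverges while prompt and config stay shared.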

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 53 additions & 22 deletions
@@ -57,6 +57,7 @@
 class RequestQueueItem:
     id: int
     request: Optional[ExecutorRequest] = None
+    child_req_ids: Optional[list] = None
     is_canceled_request: bool = False
     query: Optional[list] = None # only used in `StarAttention`
@@ -88,6 +89,13 @@ def _get_from_request_queue(
     return items
 
 
+def _get_num_child_requests(request: ExecutorRequest) -> int:
+    sampling_config = request.sampling_config
+    logger.info(sampling_config)
+    return 0 if sampling_config.beam_width > 1 else (
+        sampling_config.num_return_sequences or 1) - 1
+
+
 def _get_from_waiting_queue(
     waiting_queue: deque[RequestQueueItem],
     max_req_count: int,
@@ -108,8 +116,9 @@ def _get_from_waiting_queue(
     items = []
     req_count = 0
     while req_count < max_req_count and waiting_queue:
-        items.append(waiting_queue.popleft())
-        req_count += 1
+        req_item = waiting_queue.popleft()
+        items.append(req_item)
+        req_count += 1 + _get_num_child_requests(req_item.request)
     return items
 
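The helper above is the heart of the fan-out: beam search never spawns children (its beams live inside one request), while sampling with num_return_sequences = k spawns k - 1 children, the parent itself carrying the first sequence. _get_from_waiting_queue then charges those children against max_req_count, so the scheduling budget counts every sequence. A small self-contained sketch of the arithmetic (the stand-in config object carries only the two fields the helper reads):

from types import SimpleNamespace

def num_child_requests(sampling_config) -> int:
    # Mirrors _get_num_child_requests above.
    return 0 if sampling_config.beam_width > 1 else (
        sampling_config.num_return_sequences or 1) - 1

assert num_child_requests(SimpleNamespace(beam_width=1, num_return_sequences=4)) == 3
assert num_child_requests(SimpleNamespace(beam_width=1, num_return_sequences=None)) == 0
assert num_child_requests(SimpleNamespace(beam_width=4, num_return_sequences=4)) == 0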

@@ -359,9 +368,16 @@ def enqueue_requests(self, requests: List[ExecutorRequest]):
             start_time = time.time()
             for request in requests:
                 self.start_times[self.next_req_id] = start_time
-                self.request_queue.put(
-                    RequestQueueItem(self.next_req_id, request))
+                req_id = self.next_req_id
                 req_ids.append(self.next_req_id)
+
+                child_req_ids = []
+                num_child_requests = _get_num_child_requests(request)
+                for _ in range(num_child_requests):
+                    self.next_req_id += 1
+                    child_req_ids.append(self.next_req_id)
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, child_req_ids))
                 self.next_req_id += 1
         finally:
             self.enqueue_lock.release()
@@ -472,14 +488,23 @@ def enqueue_request(self,
         try:
             self.enqueue_lock.acquire()
             assert self.active, "PyExecutor has already been shutdown."
+            logger.info(
+                f"Enqueuing new Executor request with id {self.next_req_id}")
             req_id = self.next_req_id
             if self.enable_iter_perf_stats:
                 self.start_times[req_id] = time.time()
 
             if query is not None:
-                self.request_queue.put(RequestQueueItem(req_id, request, query))
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, [], False, query))
             else:
-                self.request_queue.put(RequestQueueItem(req_id, request))
+                child_req_ids = []
+                num_child_requests = _get_num_child_requests(request)
+                for _ in range(num_child_requests):
+                    self.next_req_id += 1
+                    child_req_ids.append(self.next_req_id)
+                self.request_queue.put(
+                    RequestQueueItem(req_id, request, child_req_ids))
             self.next_req_id += 1
         finally:
             self.enqueue_lock.release()
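Both enqueue paths now reserve a contiguous block of ids per incoming request: the parent keeps the id it would have received anyway, the children take the ids immediately after it, and next_req_id ends up past the whole block. A stand-in sketch of the numbering for two consecutive requests that each want three sequences (the helper is illustrative, not part of the executor):

def reserve_ids(next_req_id: int, num_children: int):
    """Illustrative stand-in for the id bookkeeping in enqueue_request(s)."""
    parent_id = next_req_id
    child_ids = [parent_id + 1 + i for i in range(num_children)]
    return parent_id, child_ids, parent_id + num_children + 1  # new next_req_id

parent_a, children_a, next_id = reserve_ids(10, 2)
parent_b, children_b, next_id = reserve_ids(next_id, 2)
assert (parent_a, children_a) == (10, [11, 12])
assert (parent_b, children_b) == (13, [14, 15])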
@@ -1506,12 +1531,15 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
             else:
                 raise NotImplementedError(f'unsupport cp type {cp_type}')
         else:
-            return [
-                executor_request_to_llm_request(
-                    req_item.id, req_item.request,
+            req_with_children = []
+            for req_item in new_requests:
+                req = executor_request_to_llm_request(
+                    req_item.id, req_item.request, req_item.child_req_ids,
                     self._should_exclude_last_generation_logits())
-                for req_item in new_requests
-            ]
+                req_with_children.append(req)
+                for child in req.children:
+                    req_with_children.append(child)
+            return req_with_children
 
     @nvtx_range("_schedule")
     def _schedule(self):
@@ -1977,7 +2005,7 @@ def _handle_canceled_requests(self):
             if req.id not in self.canceled_req_ids)
 
         for request in self.active_requests:
-            req_id = request.py_request_id
+            req_id = request.py_request_id if not request.is_child else request.parent_request_id
             if req_id in self.canceled_req_ids:
                 # Mark requests as finished, then, we reuse all existing code
                 # to clean up the KV cache resources.
@@ -1991,7 +2019,7 @@ def _handle_canceled_requests(self):
         self.canceled_req_ids.clear()
 
     @nvtx_range("_enqueue_responses")
-    def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
+    def _enqueue_responses(self, responses: List[Tuple[int, LlmResponse]]):
         if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
             return
 
@@ -2003,18 +2031,18 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
         else:
             responses_list = self.dist.allgather(responses)
             if self.dist.rank == 0 or self.gather_all_responses:
-                gather_responses = {}
+                gather_responses = []
                 if responses_list is not None:
                     for resp in responses_list:
                         if resp is not None:
-                            gather_responses.update(resp)
+                            gather_responses.append(resp)
                 responses = gather_responses
         logger.debug(
             f'after gather, rank = {self.dist.rank}, responses = {responses}')
 
         if self.dist.rank == 0 or self.gather_all_responses:
             with self.response_cv:
-                for req_id, resp in responses.items():
+                for req_id, resp in responses:
                     if req_id in self.responses.keys():
                         self.responses[req_id].append(resp)
                     else:
@@ -2023,20 +2051,20 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
 
     @nvtx_range("_handle_first_token_response")
     def _handle_first_token_response(self, scheduled_batch):
-        new_responses = {}
+        new_responses = []
         for req in scheduled_batch.generation_requests:
             if req.py_decoding_iter == 1:
                 logger.debug(
                     f'Send first token response for request {req.py_request_id}'
                 )
                 response = req.create_response(False, self.dist.rank)
-                new_responses.update({req.py_request_id: response})
+                new_responses.append((req.py_request_id, response))
 
         self._enqueue_responses(new_responses)
 
     @nvtx_range("_handle_responses")
     def _handle_responses(self):
-        new_responses = {}
+        new_responses = []
         requests_to_terminate = []
         new_active_requests = []
         logger.debug(
@@ -2070,14 +2098,17 @@ def _handle_responses(self):
                     request.py_decoding_iter % self.stream_interval == 0:
                 response = request.create_response(False, self.dist.rank)
                 if response:
-                    request_done = response.result.is_final
-                    new_responses.update({req_id: response})
+                    request_done = request.is_finished
+                    new_responses.append((req_id, response))
 
                 if request_done:
                     if request.is_disagg_context_transmission_state:
                         self.ctx_in_transmission_requests.append(request)
                     else:
-                        requests_to_terminate.append(request)
+                        if response.result.is_final:
+                            requests_to_terminate.append(request)
+                        for child in request.children:
+                            requests_to_terminate.append(child)
                 else:
                     new_active_requests.append(request)
         self.active_requests = new_active_requests
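Two downstream adjustments follow from the fan-out. Responses are now carried as a list of (request_id, response) tuples instead of a dict keyed by id, which lets several responses produced in one step share an id without overwriting each other, as can happen once sibling sequences surface under the same originating request. And termination is gated on the final response: a request is only appended to requests_to_terminate when response.result.is_final, and its children are queued for termination alongside it. A tiny sketch of the dict-versus-list difference (ids and payloads are illustrative):

step_responses = [(10, "sequence 0"), (10, "sequence 1"), (10, "sequence 2")]

as_dict = {}
as_list = []
for req_id, resp in step_responses:
    as_dict[req_id] = resp          # a dict keyed by id keeps only the last entry
    as_list.append((req_id, resp))  # the list keeps every sequence

assert len(as_dict) == 1
assert len(as_list) == 3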
