
Commit d11975d

fix eagle case
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent a90b238 · commit d11975d

File tree: 4 files changed (+25, -4 lines)

examples/llm-api/quickstart_advanced.py

Lines changed: 5 additions & 1 deletion
@@ -191,7 +191,11 @@ def setup_llm(args):
     greedy_decoding = ((args.temperature == 0.0)
                        or (args.top_k == 1 and
                            (args.top_p == 0.0 or args.top_p is None)))
-    mixed_sampler = not greedy_decoding and not args.enable_trtllm_sampler
+    mixed_sampler = (
+        not greedy_decoding and not args.enable_trtllm_sampler
+        # Eagle3 does not support mixed sampler.
+        # Refer TorchSampler._process_requests.
+        and spec_decode_algo != 'EAGLE3')
 
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=args.cuda_graph_batch_sizes,
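For context, a hedged sketch of the sampler-selection rule this change establishes. The helper `choose_mixed_sampler` and the `Namespace` setup below are illustrative stand-ins, not code from the repository:

from argparse import Namespace


def choose_mixed_sampler(args: Namespace, spec_decode_algo: str) -> bool:
    """Enable mixed (per-request) sampling only when it is supported."""
    greedy_decoding = ((args.temperature == 0.0)
                       or (args.top_k == 1 and
                           (args.top_p == 0.0 or args.top_p is None)))
    # Skip mixed sampling for greedy decoding, for the TRT-LLM sampler,
    # and for EAGLE3 speculative decoding, which does not support it.
    return (not greedy_decoding and not args.enable_trtllm_sampler
            and spec_decode_algo != 'EAGLE3')


args = Namespace(temperature=0.8, top_k=50, top_p=0.9, enable_trtllm_sampler=False)
assert choose_mixed_sampler(args, 'EAGLE3') is False  # EAGLE3 forces it off
assert choose_mixed_sampler(args, 'NONE') is True     # sampling otherwise enabled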

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 14 additions & 2 deletions
@@ -288,6 +288,13 @@ def finish_by(request: Union[
     request.set_finished_reason(reason, beam)
 
 
+def is_generation_only_request(
+    request: Union['LlmRequest',
+                   tensorrt_llm.bindings.internal.batch_manager.LlmRequest]
+) -> bool:
+    return request.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
+
+
 class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
     """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
     but detour some features to Python implementation"""
@@ -356,8 +363,10 @@ def __init__(
             return_generation_logits,
             exclude_last_generation_logits)
 
-    def is_generation_only_request(self):
-        return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
+    def is_generation_only(self):
+        # is_generation_only_request is a property getter at the C++ side,
+        # so here use a different name at the pytorch backend.
+        return is_generation_only_request(self)
 
     def create_child_request(self, request_id: int):
        """ Create a child request.
@@ -386,11 +395,14 @@ def create_child_request(self, request_id: int):
         child_request.py_request_id = child_request.request_id
         child_request.py_llm_request_type = child_request.llm_request_type
         child_request.py_batch_idx = None
+        child_request.py_seq_slot = None
 
         # Mimic the behavior of the original LlmRequest.
         child_request.is_attention_dp_dummy = self.is_attention_dp_dummy
         child_request.is_cuda_graph_dummy = self.is_cuda_graph_dummy
         child_request.is_dummy = self.is_dummy
+        child_request.is_generation_only = partial(is_generation_only_request,
+                                                   child_request)
         child_request.create_response = partial(create_response, child_request)
         child_request.finish_by = partial(finish_by, child_request)
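The `partial` binding above mirrors what the surrounding code already does for `create_response` and `finish_by`: the child request is a raw binding object, so the module-level helper is attached to the instance and then reads like a method call. A minimal sketch of that pattern, using a hypothetical `FakeRequest` class rather than the real binding type:

from functools import partial


class FakeRequest:
    """Hypothetical stand-in for the binding-level request object."""

    def __init__(self, request_type: str):
        self.py_llm_request_type = request_type


def is_generation_only_request(request) -> bool:
    # Module-level helper: works for any object exposing py_llm_request_type.
    return request.py_llm_request_type == 'LLMREQUEST_TYPE_GENERATION_ONLY'


child = FakeRequest('LLMREQUEST_TYPE_GENERATION_ONLY')
# Bind the free function to this instance so call sites can use it like a method.
child.is_generation_only = partial(is_generation_only_request, child)
assert child.is_generation_only() is True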

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 1 addition & 1 deletion
@@ -2081,7 +2081,7 @@ def _handle_responses(self):
                 requests_to_terminate.append(request)
                 continue
 
-            if request.is_generation_only_request():
+            if request.is_generation_only():
                 # If request is in transmission, so we don't need to emit a response
                 # Also, for the first iteration with overlap, we should skip since first
                 # token has already been emitted previously
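Call sites therefore switch to the shorter name. A hedged sketch of a response loop that skips such requests; the loop below is illustrative only, not the executor's actual `_handle_responses` logic:

def collect_responses(active_requests):
    responses = []
    for request in active_requests:
        # Generation-only requests are still in transmission, so no response
        # is emitted for them on this iteration.
        if request.is_generation_only():
            continue
        responses.append(request.create_response())
    return responses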

tests/unittest/_torch/test_llm_request.py

Lines changed: 5 additions & 0 deletions
@@ -215,6 +215,8 @@ def test_create_response():
     child_request.state = LlmRequestState.GENERATION_IN_PROGRESS
 
     response = request.create_response(use_fast_logits=True, mpi_world_rank=1)
+    # The response having non-error result contain _result.
+    response.result.deserialize()
     assert response is not None
     assert isinstance(response, LlmResponse)
     assert response.request_id == 1
@@ -224,6 +226,7 @@ def test_create_response():
     assert response.result.sequence_index == 0
 
     child_response = child_request.create_response()
+    child_response.result.deserialize()
    assert child_response is not None
    assert child_response.request_id == 2
    assert child_response.client_id == child_request.py_client_id
@@ -234,12 +237,14 @@ def test_create_response():
     child_request.state = LlmRequestState.GENERATION_COMPLETE
     # is_final=False since the parent request is not yet complete.
     child_response = child_request.create_response()
+    child_response.result.deserialize()
     assert child_response.result.is_final is False
     assert child_response.result.is_sequence_final is True
 
     # is_final=True since all requests are complete.
     request.state = LlmRequestState.GENERATION_COMPLETE
     response = request.create_response()
+    response.result.deserialize()
     assert response.result.is_final is True
     assert response.result.is_sequence_final is True
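The new `deserialize()` calls materialize the serialized result before its fields are read. A rough sketch of that pattern, with a hypothetical `FakeResult` standing in for the real binding type:

class FakeResult:
    """Hypothetical stand-in: fields become readable only after deserialize()."""

    def __init__(self, payload: dict):
        self._serialized = payload
        self.is_final = None
        self.is_sequence_final = None

    def deserialize(self) -> None:
        # Populate plain attributes from the serialized payload.
        self.is_final = self._serialized['is_final']
        self.is_sequence_final = self._serialized['is_sequence_final']


result = FakeResult({'is_final': True, 'is_sequence_final': True})
result.deserialize()
assert result.is_final is True
assert result.is_sequence_final is True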
