2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -870,7 +870,7 @@ def _executor_loop(self):
self.resource_manager.prepare_resources(scheduled_batch)
if self.drafter is not None:
self.drafter.prepare_draft_tokens(
scheduled_batch, self.resource_manager)
scheduled_batch, self.resource_manager, iter_stats)

batch_outputs = self._forward_step(scheduled_batch)

3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/speculative/drafter.py
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional

from tensorrt_llm.bindings.executor import IterationStats

from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests

@@ -13,6 +15,7 @@ def prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
iter_stats: IterationStats = None,
) -> None:
"""
Prepare the drafter tokens for the forward computation this step.
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/speculative/model_drafter.py
@@ -6,6 +6,7 @@
import torch

from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.bindings.executor import IterationStats
from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig
@@ -297,6 +298,7 @@ def prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
iter_stats: IterationStats = None,
) -> None:
"""
Prepare draft tokens for the scheduled requests.
62 changes: 61 additions & 1 deletion tensorrt_llm/_torch/speculative/ngram.py
@@ -1,7 +1,10 @@
import time
from itertools import chain
from typing import Dict, Tuple

from ordered_set import OrderedSet

from tensorrt_llm.bindings.executor import IterationStats
from tensorrt_llm.logger import logger

from ..pyexecutor.llm_request import *
@@ -174,6 +177,21 @@ def prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
iter_stats: IterationStats = None,
) -> None:

if iter_stats is not None:
start_time = time.time()

self._prepare_draft_tokens(scheduled_requests)

if iter_stats is not None:
self._update_ngram_iter_stats(scheduled_requests, iter_stats,
start_time)

def _prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
) -> None:
# Sort by request_id when py_batch_idx is None as a fallback.
# This happens in the disagg case: for a set of new requests, we draft
@@ -195,6 +213,48 @@
)
# Pad length to `self.max_draft_len`
if len(draft_tokens) > 0:
pad_length = self.max_draft_len - len(draft_tokens)
draft_length = len(draft_tokens)
pad_length = self.max_draft_len - draft_length
draft_tokens.extend([request.py_end_id] * pad_length)
else:
draft_length = 0
request.py_draft_tokens = draft_tokens
request.py_draft_length = draft_length

def _update_ngram_iter_stats(
Collaborator review comment: Maybe this function can be moved to the Drafter class (a hedged sketch of that refactor follows after this function).

self,
scheduled_requests: ScheduledRequests,
iter_stats: IterationStats,
start_time: float,
) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
"""
Get statistic information from the draft tokens in NGram drafter
"""
now_time = time.time()

total_num_draft_tokens = 0
total_num_accepted_tokens = 0
num_requests_with_draft_tokens = 0
for request in scheduled_requests.generation_requests:
if request.py_last_draft_tokens is not None:
total_num_draft_tokens += request.py_draft_length
total_num_accepted_tokens += request.py_num_accepted_draft_tokens
num_requests_with_draft_tokens += 1
Collaborator review comment: I think this number will always be the number of scheduled requests for one iteration. If num_requests_with_draft_tokens should exclude requests with only padded tokens, then maybe we should count requests with py_draft_length > 0 (the sketch after this function folds this in).


if num_requests_with_draft_tokens > 0:
iter_stats.specdec_stats.iter_latency_ms = (now_time -
start_time) * 1e3
iter_stats.specdec_stats.num_draft_tokens = total_num_draft_tokens
iter_stats.specdec_stats.num_accepted_tokens = total_num_accepted_tokens
iter_stats.specdec_stats.num_requests_with_draft_tokens = num_requests_with_draft_tokens
iter_stats.specdec_stats.acceptance_length = (
total_num_accepted_tokens +
num_requests_with_draft_tokens) / num_requests_with_draft_tokens
else:
iter_stats.specdec_stats.iter_latency_ms = 0.0
iter_stats.specdec_stats.num_draft_tokens = 0
iter_stats.specdec_stats.num_accepted_tokens = 0
iter_stats.specdec_stats.num_requests_with_draft_tokens = 0
iter_stats.specdec_stats.acceptance_length = 1.0

return
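Putting the two collaborator suggestions above together, here is a hedged sketch of what hoisting this helper into the Drafter base class could look like. The SpecDecStatsSketch and RequestSketch stand-ins, the DrafterSketch class, and the method name update_spec_dec_iter_stats are illustrative names, not part of this PR; only the py_* attribute names and the stats fields are taken from the diff.

import time
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SpecDecStatsSketch:
    # Stand-in for iter_stats.specdec_stats; the real object comes from
    # tensorrt_llm.bindings.executor.IterationStats and may differ.
    iter_latency_ms: float = 0.0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_requests_with_draft_tokens: int = 0
    acceptance_length: float = 1.0


@dataclass
class RequestSketch:
    # Only the py_* fields the stats update reads, mirroring the diff above.
    py_last_draft_tokens: Optional[List[int]] = None
    py_draft_length: int = 0
    py_num_accepted_draft_tokens: int = 0


class DrafterSketch:
    """Base-class home for the stats helper, per the first review comment."""

    def update_spec_dec_iter_stats(self, generation_requests, specdec_stats,
                                   start_time: float) -> None:
        now_time = time.time()
        total_draft = total_accepted = num_with_draft = 0
        for request in generation_requests:
            # Per the second review comment: only count requests that carried
            # real (non-padded) draft tokens, i.e. py_draft_length > 0.
            if request.py_last_draft_tokens is not None and request.py_draft_length > 0:
                total_draft += request.py_draft_length
                total_accepted += request.py_num_accepted_draft_tokens
                num_with_draft += 1

        if num_with_draft > 0:
            specdec_stats.iter_latency_ms = (now_time - start_time) * 1e3
            specdec_stats.num_draft_tokens = total_draft
            specdec_stats.num_accepted_tokens = total_accepted
            specdec_stats.num_requests_with_draft_tokens = num_with_draft
            # One target token per request plus its accepted drafts, averaged.
            specdec_stats.acceptance_length = (
                total_accepted + num_with_draft) / num_with_draft
        else:
            specdec_stats.iter_latency_ms = 0.0
            specdec_stats.num_draft_tokens = 0
            specdec_stats.num_accepted_tokens = 0
            specdec_stats.num_requests_with_draft_tokens = 0
            specdec_stats.acceptance_length = 1.0

Under that sketch, an NGram-style subclass would call self.update_spec_dec_iter_stats(scheduled_requests.generation_requests, iter_stats.specdec_stats, start_time) at the end of prepare_draft_tokens instead of keeping its own _update_ngram_iter_stats.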
Comment on lines +224 to +260
Contributor review comment:

⚠️ Potential issue

Fix the return type annotation; the statistics collection logic itself is correct.

The statistics collection implementation is correct and properly uses py_draft_length to exclude padding. However, the return type annotation doesn't match the actual return value.

Apply this diff to fix the return type annotation:

-    ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
+    ) -> None:

The rest of the implementation looks good - the acceptance length calculation and latency measurement are correct, and the conditional handling for cases with/without draft tokens is appropriate.
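As a quick sanity check of the acceptance-length formula (the numbers below are illustrative, not taken from this PR):

# Two requests carried draft tokens this iteration: one had 3 drafts accepted,
# the other 0. Each request always yields one target token on top of that.
total_num_accepted_tokens = 3 + 0
num_requests_with_draft_tokens = 2
acceptance_length = (total_num_accepted_tokens +
                     num_requests_with_draft_tokens) / num_requests_with_draft_tokens
assert acceptance_length == 2.5  # on average 2.5 tokens emitted per verification step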

Ruff (0.12.2), line 229: LlmRequest may be undefined, or defined from star imports (F405)

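One conventional way to address that F405 finding would be to replace the star import in ngram.py with explicit names. The line below is only a sketch: it assumes LlmRequest is the symbol ngram.py needs from that module, and the file may rely on more.

# Mirrors the explicit import style model_drafter.py already uses; extend the
# list if ngram.py references additional names from llm_request.
from ..pyexecutor.llm_request import LlmRequest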