Merged
Commits
58 commits
76de3b1
unify new_tokens format
netanel-haber Jun 5, 2025
34e0a0b
sort out mixed_sampler
netanel-haber Jun 6, 2025
ecd502d
refactor torch sampler
netanel-haber Jun 8, 2025
8116a93
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 8, 2025
bb42c25
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 9, 2025
a519c6b
eagle3?
netanel-haber Jun 8, 2025
1e058e8
fix mtp
netanel-haber Jun 9, 2025
751f001
all_requests()
netanel-haber Jun 9, 2025
cf00142
eagle???
netanel-haber Jun 9, 2025
208af17
sanity
netanel-haber Jun 10, 2025
1822f05
revert eagle
netanel-haber Jun 10, 2025
f8861a2
it works
netanel-haber Jun 10, 2025
151b565
sort
netanel-haber Jun 10, 2025
a6102ff
still works
netanel-haber Jun 10, 2025
860f76a
still works
netanel-haber Jun 10, 2025
57a42b0
test works now
netanel-haber Jun 10, 2025
f705a90
like an onion
netanel-haber Jun 10, 2025
abd06c1
still works
netanel-haber Jun 10, 2025
01d3c28
Eagle3 actually works?
netanel-haber Jun 10, 2025
92bb1d7
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 10, 2025
effd59e
Update tensorrt_llm/_torch/pyexecutor/scheduler.py
netanel-haber Jun 11, 2025
26f4d83
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 11, 2025
2743423
restore instantiate_sampler logic
netanel-haber Jun 11, 2025
0281ce7
import
netanel-haber Jun 11, 2025
a77a3a9
revert is_mtp
netanel-haber Jun 11, 2025
50f8ad8
fixes
netanel-haber Jun 11, 2025
a77ddf1
so be it
netanel-haber Jun 11, 2025
8c9f17c
minimize diff
netanel-haber Jun 11, 2025
debc5f7
clarity
netanel-haber Jun 11, 2025
9a9b3bb
codespell
netanel-haber Jun 11, 2025
dbc10f5
fix EarlyStopSampler
netanel-haber Jun 11, 2025
5d25ff9
add seq_slot to disagg requests
netanel-haber Jun 11, 2025
fb20674
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 12, 2025
1be314b
always assign a seq_slot unless empty
netanel-haber Jun 12, 2025
b5f71e1
don't sort previous_batch_indices?
netanel-haber Jun 15, 2025
8ffb672
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 15, 2025
8acf382
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 16, 2025
648003b
tensor is not None
netanel-haber Jun 16, 2025
e25ec28
fix log probs
netanel-haber Jun 16, 2025
692b859
cleanliness
netanel-haber Jun 16, 2025
7f44c16
trivial
netanel-haber Jun 17, 2025
4c72403
solve MTP
netanel-haber Jun 17, 2025
f8fd75d
Merge commit 'e607768e451d34c797fed548ba6372638f280eba' into user/nha…
netanel-haber Jun 17, 2025
0e4446c
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 17, 2025
8ff9511
double check that a request is context before checking if it is chunk…
netanel-haber Jun 18, 2025
3df266a
fix log probs
netanel-haber Jun 18, 2025
dfe4359
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 18, 2025
2390683
yapf and isort conflict
netanel-haber Jun 18, 2025
9deba39
dup
netanel-haber Jun 18, 2025
22eaab9
dup
netanel-haber Jun 18, 2025
1114b2c
["decoder_state"]
netanel-haber Jun 19, 2025
d5313e2
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 19, 2025
dd8a660
Merge branch 'main' into user/nhaber/feature/align_sample_state_with_…
netanel-haber Jun 19, 2025
f3914f7
can't freeze SampleStateTensors because of .logits is assigned to in …
netanel-haber Jun 19, 2025
feb6c17
unfreeze classes
netanel-haber Jun 19, 2025
9f98a0b
max_num_sequences is the number of seq_slots, not max_batch_size, swi…
netanel-haber Jun 19, 2025
cdec975
fix ad_deploy with torchsampler, and add a seq_slot manager
netanel-haber Jun 19, 2025
e7ad537
get rid of TorchStarAttentionSampler - it wasn't doing anything inter…
netanel-haber Jun 19, 2025
18 changes: 14 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -4,6 +4,7 @@
import torch
from torch._prims_common import DeviceLikeType

from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
from tensorrt_llm._utils import nvtx_range

from ...._utils import mpi_rank, mpi_world_size
@@ -12,6 +13,7 @@
from ....llmapi.llm_args import _AutoDeployLlmArgs
from ....mapping import Mapping
from ...distributed import MPIDist
from ...pyexecutor._util import create_torch_sampler_args
from ...pyexecutor.config import PyTorchConfig
from ...pyexecutor.model_engine import ModelEngine
from ...pyexecutor.py_executor import PyExecutor
@@ -292,7 +294,13 @@ def create_autodeploy_executor(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
)
resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
seq_slot_manager = SeqSlotManager(max_num_sequences=max_batch_size * dist_mapping.pp_size)
resource_manager = ResourceManager(
{
ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
}
)
resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)

# scheduling
@@ -303,15 +311,17 @@
scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)

# search sampler with speculative decoding
sampler = TorchSampler(max_seq_len=max_seq_len)

# creating the executor object
sampler_args = create_torch_sampler_args(
executor_config, dist_mapping, mixed_sampler=False, max_seq_len=max_seq_len
)
sampler = TorchSampler(sampler_args)
py_executor = PyExecutor(
resource_manager,
scheduler,
model_engine=engine,
sampler=sampler,
dist=mpi_dist,
max_num_sequences=ad_config.max_batch_size * dist_mapping.pp_size,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,
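The auto-deploy shim now registers a SeqSlotManager next to the KV-cache manager and builds its TorchSampler through the shared create_torch_sampler_args helper instead of hard-coding the constructor. As a rough illustration of what a sequence-slot manager provides, here is a minimal pool-of-slots sketch; the class, method names, and internals are illustrative only and are not the actual SeqSlotManager implementation:

```python
class SlotPoolSketch:
    """Illustrative slot pool: each live request holds one sequence slot."""

    def __init__(self, max_num_sequences: int):
        # One slot per concurrently active sequence
        # (max_batch_size * pp_size in the diff above).
        self._free = list(range(max_num_sequences))
        self._assigned: dict[int, int] = {}  # request_id -> slot index

    def acquire(self, request_id: int) -> int:
        # Reuse the request's slot if it already has one, otherwise pop a free slot.
        if request_id not in self._assigned:
            if not self._free:
                raise RuntimeError("no free sequence slots")
            self._assigned[request_id] = self._free.pop()
        return self._assigned[request_id]

    def release(self, request_id: int) -> None:
        # Return the slot to the pool once the request finishes.
        self._free.append(self._assigned.pop(request_id))
```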
51 changes: 32 additions & 19 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -26,8 +26,7 @@
from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
PeftCacheManager, ResourceManager,
ResourceManagerType)
from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
TRTLLMSampler)
from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
from .seq_slot_manager import SeqSlotManager
@@ -506,6 +505,7 @@ def create_py_executor_instance(
model_engine=model_engine,
sampler=sampler,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
@@ -517,31 +517,44 @@
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)


def instantiate_sampler(model_engine: PyTorchModelEngine,
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
*, max_seq_len: int, mixed_sampler: bool):
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_draft_tokens = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_tokens)
return TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=mixed_sampler,
)


def instantiate_sampler(engine: PyTorchModelEngine,
executor_config: ExecutorConfig,
pytorch_backend_config: PyTorchConfig,
mapping: Mapping):
sampler_args = create_torch_sampler_args(
executor_config,
mapping,
max_seq_len=engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
if mapping.cp_config.get('cp_type') == 'star_attention':
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
sampler = TorchStarAttentionSampler(
max_seq_len=model_engine.max_seq_len)
elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder(
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
):
sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len,
spec_config=model_engine.spec_config)
elif pytorch_backend_config.enable_trtllm_sampler:
return get_spec_decoder(sampler_args, engine.spec_config)
if pytorch_backend_config.enable_trtllm_sampler:
decoding_mode = get_decoding_mode(executor_config)
sampler = TRTLLMSampler(
executor_config, model_engine.model, model_engine.dtype, mapping,
decoding_mode, pytorch_backend_config.disable_overlap_scheduler)
elif not model_engine.model.model_config.is_generation:
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
mapping, decoding_mode,
pytorch_backend_config.disable_overlap_scheduler)
if not engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
sampler = EarlyStopSampler()
else:
sampler = TorchSampler(
max_seq_len=model_engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
return sampler
return EarlyStopSampler()
return TorchSampler(sampler_args)


def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
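instantiate_sampler now builds one TorchSampler.Args bundle up front and dispatches with early returns instead of an if/elif chain. A usage sketch of the new argument bundle, assuming the import path of the module touched in this PR; the concrete values are placeholders:

```python
from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler

# Field names follow TorchSampler.Args as shown in the diff; values are illustrative.
sampler_args = TorchSampler.Args(
    max_seq_len=4096,
    max_draft_tokens=0,        # 0 when no speculative_config is set
    max_num_sequences=16 * 2,  # executor_config.max_batch_size * mapping.pp_size
    max_beam_width=1,
    mixed_sampler=False,
)
sampler = TorchSampler(sampler_args)
```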
8 changes: 2 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,4 +1,3 @@
import itertools
import math
from typing import List, Optional

@@ -52,8 +51,7 @@ def bitmask_size(self) -> int:

def build(self, scheduled_requests: ScheduledRequests,
resource_manager: SeqSlotManager) -> None:
for llm_req in itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
for llm_req in scheduled_requests.all_requests():
if llm_req.guided_decoding_params is None:
continue
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
@@ -84,9 +82,7 @@ def execute(self, scheduled_requests: ScheduledRequests,
torch.cuda.current_stream().wait_stream(self._stream)

batched_logits, batched_bitmask = [], []
for i, llm_req in enumerate(
itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)):
for i, llm_req in enumerate(scheduled_requests.all_requests()):
if llm_req.guided_decoding_params is None:
continue
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
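Both loops in the guided decoder now iterate over scheduled_requests.all_requests() rather than chaining the two request lists inline. Inferred from the itertools.chain calls it replaces, the helper presumably looks roughly like the sketch below; this is not the verbatim ScheduledRequests implementation:

```python
from dataclasses import dataclass, field


@dataclass
class ScheduledRequestsSketch:
    context_requests: list = field(default_factory=list)
    generation_requests: list = field(default_factory=list)

    def all_requests(self) -> list:
        # Preserves the iteration order the removed itertools.chain produced:
        # context requests first, then generation requests.
        return self.context_requests + self.generation_requests
```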
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -253,6 +253,7 @@ def __init__(
return_logits_device_memory: bool = True,
exclude_last_generation_logits: bool = False,
stop_words_list: list[list[int]] | None = None,
is_draft: bool = False,
**kwargs):
self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
None)
@@ -288,6 +289,7 @@ def __init__(
self.py_return_context_logits = return_context_logits
self.py_return_generation_logits = return_generation_logits
self.py_return_logits_device_memory = return_logits_device_memory
self.py_is_draft = is_draft

# TODO: remove this when use DynamicDecodeOp in pytorch flow.
# currently, keep py_stop_words_list as python list, rather than tensor.
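The new is_draft keyword defaults to False, so existing call sites are unchanged, and the value is exposed on the Python side as py_is_draft. A hypothetical consumer sketch; only the attribute itself comes from this diff, the partitioning helper is illustrative:

```python
from typing import Iterable, List, Tuple


def partition_draft_requests(requests: Iterable) -> Tuple[List, List]:
    """Split requests into draft-model and target-model requests (illustrative)."""
    draft, target = [], []
    for req in requests:
        (draft if getattr(req, "py_is_draft", False) else target).append(req)
    return draft, target
```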