24 changes: 2 additions & 22 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -4,7 +4,6 @@
import torch
from torch._prims_common import DeviceLikeType

from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
from tensorrt_llm._utils import nvtx_range

from ...._utils import mpi_rank, mpi_world_size
@@ -265,7 +264,6 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config

max_batch_size = ad_config.max_batch_size
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
max_seq_len = ad_config.max_seq_len
attn_page_size = ad_config.attn_page_size
max_num_tokens = ad_config.max_num_tokens
@@ -296,13 +294,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
)
seq_slot_manager = SeqSlotManager(max_num_sequences=max_batch_size * dist_mapping.pp_size)
resource_manager = ResourceManager(
{
ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
}
)
resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)

# scheduling
@@ -313,18 +305,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)

# search sampler with speculative decoding
# TODO (lucaslie, fridah-nv): some models require mixed_sampler=True to have good outputs, see
# https://github.com/NVIDIA/TensorRT-LLM/issues/5254
# We should expose mixed_sample to our build_and_run_ad script so we can configure this
# correctly for models as needed.
sampler_args = TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=ad_config.mixed_sampler,
)
sampler = TorchSampler(sampler_args)
sampler = TorchSampler(max_seq_len=max_seq_len)

# creating the executor object
py_executor = PyExecutor(
@@ -333,7 +314,6 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
model_engine=engine,
sampler=sampler,
dist=mpi_dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
max_batch_size=max_batch_size,
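After this diff, `create_autodeploy_executor` registers only the KV cache manager and builds the sampler directly from `max_seq_len`, dropping the `SeqSlotManager`, `max_num_sequences`, and the `TorchSampler.Args` container. A condensed sketch of how the touched lines read after the change (variables such as `kv_cache_manager` and `max_seq_len`, and the `ResourceManager`/`ResourceManagerType`/`TorchSampler` imports, are assumed to be in scope as elsewhere in the file):

```python
# Only the KV cache manager remains registered; the sequence-slot manager entry is gone.
resource_manager = ResourceManager(
    {ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
resource_manager.resource_managers.move_to_end(
    ResourceManagerType.KV_CACHE_MANAGER, last=True)

# The sampler is now constructed directly from the maximum sequence length
# instead of going through TorchSampler.Args.
sampler = TorchSampler(max_seq_len=max_seq_len)
```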
51 changes: 19 additions & 32 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -26,7 +26,8 @@
from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
PeftCacheManager, ResourceManager,
ResourceManagerType)
from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
TRTLLMSampler)
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
from .seq_slot_manager import SeqSlotManager
@@ -511,7 +512,6 @@ def create_py_executor_instance(
model_engine=model_engine,
sampler=sampler,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
@@ -523,44 +523,31 @@
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
*, max_seq_len: int, mixed_sampler: bool):
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_draft_tokens = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_tokens)
return TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=mixed_sampler,
)


def instantiate_sampler(engine: PyTorchModelEngine,
def instantiate_sampler(model_engine: PyTorchModelEngine,
executor_config: ExecutorConfig,
pytorch_backend_config: PyTorchConfig,
mapping: Mapping):
sampler_args = create_torch_sampler_args(
executor_config,
mapping,
max_seq_len=engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
if mapping.cp_config.get('cp_type') == 'star_attention':
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
sampler = TorchStarAttentionSampler(
max_seq_len=model_engine.max_seq_len)
elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder(
):
return get_spec_decoder(sampler_args, engine.spec_config)
if pytorch_backend_config.enable_trtllm_sampler:
sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len,
spec_config=model_engine.spec_config)
elif pytorch_backend_config.enable_trtllm_sampler:
decoding_mode = get_decoding_mode(executor_config)
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
mapping, decoding_mode,
pytorch_backend_config.disable_overlap_scheduler)
if not engine.model.model_config.is_generation:
sampler = TRTLLMSampler(
executor_config, model_engine.model, model_engine.dtype, mapping,
decoding_mode, pytorch_backend_config.disable_overlap_scheduler)
elif not model_engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
return EarlyStopSampler()
return TorchSampler(sampler_args)
sampler = EarlyStopSampler()
else:
sampler = TorchSampler(
max_seq_len=model_engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
return sampler


def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
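The rewritten `instantiate_sampler` now assigns a single `sampler` variable from one if/elif chain instead of building a `TorchSampler.Args` up front. A condensed sketch of the selection order, using only names that appear in the hunks above and assuming the symbols imported at the top of `_util.py` (including `get_spec_decoder` and `get_decoding_mode`) are in scope; the helper name `select_sampler_sketch` is hypothetical and this is a simplified restatement, not the exact function body:

```python
def select_sampler_sketch(model_engine, executor_config, pytorch_backend_config, mapping):
    # Star attention has a dedicated sampler (and requires the matching attention backend).
    if mapping.cp_config.get('cp_type') == 'star_attention':
        return TorchStarAttentionSampler(max_seq_len=model_engine.max_seq_len)
    # Speculative-decoding modes that ship their own decoder come next.
    if (model_engine.spec_config is not None
            and model_engine.spec_config.spec_dec_mode.has_spec_decoder()):
        return get_spec_decoder(max_seq_len=model_engine.max_seq_len,
                                spec_config=model_engine.spec_config)
    # Optionally use the TRT-LLM sampler.
    if pytorch_backend_config.enable_trtllm_sampler:
        decoding_mode = get_decoding_mode(executor_config)
        return TRTLLMSampler(executor_config, model_engine.model, model_engine.dtype,
                             mapping, decoding_mode,
                             pytorch_backend_config.disable_overlap_scheduler)
    # Models that do not generate tokens stop early.
    if not model_engine.model.model_config.is_generation:
        return EarlyStopSampler()
    # Default: the plain PyTorch sampler.
    return TorchSampler(max_seq_len=model_engine.max_seq_len,
                        mixed_sampler=pytorch_backend_config.mixed_sampler)
```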
8 changes: 6 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,3 +1,4 @@
import itertools
import math
from typing import List, Optional

@@ -51,7 +52,8 @@ def bitmask_size(self) -> int:

def build(self, scheduled_requests: ScheduledRequests,
resource_manager: SeqSlotManager) -> None:
for llm_req in scheduled_requests.all_requests():
for llm_req in itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
if llm_req.guided_decoding_params is None:
continue
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
@@ -82,7 +84,9 @@ def execute(self, scheduled_requests: ScheduledRequests,
torch.cuda.current_stream().wait_stream(self._stream)

batched_logits, batched_bitmask = [], []
for i, llm_req in enumerate(scheduled_requests.all_requests()):
for i, llm_req in enumerate(
itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)):
if llm_req.guided_decoding_params is None:
continue
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
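Both hunks replace `scheduled_requests.all_requests()` with an explicit `itertools.chain` over the context and generation request lists, which is why `import itertools` is added at the top of the file. A minimal standalone illustration of the iteration pattern (the request lists below are hypothetical placeholders, not real `LlmRequest` objects):

```python
import itertools

# Hypothetical stand-ins for ScheduledRequests.context_requests and
# ScheduledRequests.generation_requests.
context_requests = ["ctx_req_0", "ctx_req_1"]
generation_requests = ["gen_req_0"]

# Context requests are visited first, then generation requests, and the
# enumerate() index runs continuously across both lists.
for i, llm_req in enumerate(itertools.chain(context_requests, generation_requests)):
    print(i, llm_req)
```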
2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -253,7 +253,6 @@ def __init__(
return_logits_device_memory: bool = True,
exclude_last_generation_logits: bool = False,
stop_words_list: list[list[int]] | None = None,
is_draft: bool = False,
**kwargs):
self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
None)
@@ -287,7 +286,6 @@ def __init__(
self.py_return_context_logits = return_context_logits
self.py_return_generation_logits = return_generation_logits
self.py_return_logits_device_memory = return_logits_device_memory
self.py_is_draft = is_draft

# TODO: remove this when use DynamicDecodeOp in pytorch flow.
# currently, keep py_stop_words_list as python list, rather than tensor.