diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index a733f37b1b2..840241641a2 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -4,7 +4,6 @@ import torch from torch._prims_common import DeviceLikeType -from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager from tensorrt_llm._utils import nvtx_range from ...._utils import mpi_rank, mpi_world_size @@ -265,7 +264,6 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config max_batch_size = ad_config.max_batch_size - max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size max_seq_len = ad_config.max_seq_len attn_page_size = ad_config.attn_page_size max_num_tokens = ad_config.max_num_tokens @@ -296,13 +294,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: max_seq_len=max_seq_len, max_batch_size=max_batch_size, ) - seq_slot_manager = SeqSlotManager(max_num_sequences=max_batch_size * dist_mapping.pp_size) - resource_manager = ResourceManager( - { - ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager, - ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager, - } - ) + resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager}) resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True) # scheduling @@ -313,18 +305,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler) # search sampler with speculative decoding - # TODO (lucaslie, fridah-nv): some models require mixed_sampler=True to have good outputs, see - # https://github.com/NVIDIA/TensorRT-LLM/issues/5254 - # We should expose mixed_sample to our build_and_run_ad script so we can configure this - # correctly for models as needed. 
- sampler_args = TorchSampler.Args( - max_seq_len=max_seq_len, - max_draft_tokens=max_draft_tokens, - max_num_sequences=max_num_sequences, - max_beam_width=executor_config.max_beam_width, - mixed_sampler=ad_config.mixed_sampler, - ) - sampler = TorchSampler(sampler_args) + sampler = TorchSampler(max_seq_len=max_seq_len) # creating the executor object py_executor = PyExecutor( @@ -333,7 +314,6 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: model_engine=engine, sampler=sampler, dist=mpi_dist, - max_num_sequences=max_num_sequences, disable_overlap_scheduler=ad_config.disable_overlap_scheduler, max_input_len=ad_config.max_input_len, max_batch_size=max_batch_size, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 6306afc1ccc..4000f39329f 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -26,7 +26,8 @@ from .resource_manager import (KVCacheManager, MambaHybridCacheManager, PeftCacheManager, ResourceManager, ResourceManagerType) -from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler +from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler, + TRTLLMSampler) from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, SimpleScheduler) from .seq_slot_manager import SeqSlotManager @@ -511,7 +512,6 @@ def create_py_executor_instance( model_engine=model_engine, sampler=sampler, dist=dist, - max_num_sequences=max_num_sequences, disable_overlap_scheduler=pytorch_backend_config. disable_overlap_scheduler, max_batch_size=executor_config.max_batch_size, @@ -523,44 +523,31 @@ def create_py_executor_instance( garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) -def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping, - *, max_seq_len: int, mixed_sampler: bool): - max_num_sequences = executor_config.max_batch_size * mapping.pp_size - max_draft_tokens = (0 if executor_config.speculative_config is None else - executor_config.speculative_config.max_draft_tokens) - return TorchSampler.Args( - max_seq_len=max_seq_len, - max_draft_tokens=max_draft_tokens, - max_num_sequences=max_num_sequences, - max_beam_width=executor_config.max_beam_width, - mixed_sampler=mixed_sampler, - ) - - -def instantiate_sampler(engine: PyTorchModelEngine, +def instantiate_sampler(model_engine: PyTorchModelEngine, executor_config: ExecutorConfig, pytorch_backend_config: PyTorchConfig, mapping: Mapping): - sampler_args = create_torch_sampler_args( - executor_config, - mapping, - max_seq_len=engine.max_seq_len, - mixed_sampler=pytorch_backend_config.mixed_sampler) if mapping.cp_config.get('cp_type') == 'star_attention': assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'" - return TorchSampler(sampler_args) - if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder( + sampler = TorchStarAttentionSampler( + max_seq_len=model_engine.max_seq_len) + elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder( ): - return get_spec_decoder(sampler_args, engine.spec_config) - if pytorch_backend_config.enable_trtllm_sampler: + sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len, + spec_config=model_engine.spec_config) + elif pytorch_backend_config.enable_trtllm_sampler: decoding_mode = get_decoding_mode(executor_config) - return 
TRTLLMSampler(executor_config, engine.model, engine.dtype, - mapping, decoding_mode, - pytorch_backend_config.disable_overlap_scheduler) - if not engine.model.model_config.is_generation: + sampler = TRTLLMSampler( + executor_config, model_engine.model, model_engine.dtype, mapping, + decoding_mode, pytorch_backend_config.disable_overlap_scheduler) + elif not model_engine.model.model_config.is_generation: # NOTE: choose sampler based on model type - return EarlyStopSampler() - return TorchSampler(sampler_args) + sampler = EarlyStopSampler() + else: + sampler = TorchSampler( + max_seq_len=model_engine.max_seq_len, + mixed_sampler=pytorch_backend_config.mixed_sampler) + return sampler def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode: diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index 756c177a6ea..fc21a2096e2 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -1,3 +1,4 @@ +import itertools import math from typing import List, Optional @@ -51,7 +52,8 @@ def bitmask_size(self) -> int: def build(self, scheduled_requests: ScheduledRequests, resource_manager: SeqSlotManager) -> None: - for llm_req in scheduled_requests.all_requests(): + for llm_req in itertools.chain(scheduled_requests.context_requests, + scheduled_requests.generation_requests): if llm_req.guided_decoding_params is None: continue slot = resource_manager.slot_manager.get_slot(llm_req.request_id) @@ -82,7 +84,9 @@ def execute(self, scheduled_requests: ScheduledRequests, torch.cuda.current_stream().wait_stream(self._stream) batched_logits, batched_bitmask = [], [] - for i, llm_req in enumerate(scheduled_requests.all_requests()): + for i, llm_req in enumerate( + itertools.chain(scheduled_requests.context_requests, + scheduled_requests.generation_requests)): if llm_req.guided_decoding_params is None: continue if llm_req.is_context_init_state and not llm_req.is_last_context_chunk: diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index f16e4e2dcfa..01e9324e987 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -253,7 +253,6 @@ def __init__( return_logits_device_memory: bool = True, exclude_last_generation_logits: bool = False, stop_words_list: list[list[int]] | None = None, - is_draft: bool = False, **kwargs): self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", None) @@ -287,7 +286,6 @@ def __init__( self.py_return_context_logits = return_context_logits self.py_return_generation_logits = return_generation_logits self.py_return_logits_device_memory = return_logits_device_memory - self.py_is_draft = is_draft # TODO: remove this when use DynamicDecodeOp in pytorch flow. # currently, keep py_stop_words_list as python list, rather than tensor. 
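Note on a recurring change in this patch: the call sites above (guided_decoder.py) and below (model_engine.py, py_executor.py, seq_slot_manager.py) replace scheduled_requests.all_requests() with an explicit itertools.chain(context_requests, generation_requests), and the scheduler.py hunk near the end of the patch correspondingly turns all_requests back into a property returning a chain rather than a method returning a concatenated list. The standalone sketch below (not part of the patch; the toy class is a hypothetical stand-in for ScheduledRequests) shows the practical difference between the two forms:

import itertools
from dataclasses import dataclass, field

@dataclass
class ToyScheduledRequests:  # hypothetical stand-in for ScheduledRequests
    context_requests: list = field(default_factory=list)
    generation_requests: list = field(default_factory=list)

    @property
    def all_requests(self):
        # Single-pass iterator: cheap to build, but supports no len() or indexing.
        return itertools.chain(self.context_requests, self.generation_requests)

    def all_requests_as_list(self):
        # Behavior of the removed method: materializes a new list on every call.
        return self.context_requests + self.generation_requests

batch = ToyScheduledRequests(context_requests=["ctx0"],
                             generation_requests=["gen0", "gen1"])
assert list(batch.all_requests) == batch.all_requests_as_list()
# len(batch.all_requests) would raise TypeError; callers that need a sequence
# must build a list explicitly instead of relying on the property.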
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e877117ec85..3a454cb740f 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -4,6 +4,7 @@ import gc import glob import inspect +import itertools import math import multiprocessing import os @@ -20,7 +21,6 @@ import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub -from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank, @@ -319,7 +319,6 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], class PyTorchModelEngine(ModelEngine): - BEAM_WIDTH = 1 def __init__( self, @@ -660,12 +659,13 @@ def get_autotune_warmup_request(): return result @contextlib.contextmanager - def release_batch(result: ScheduledRequests | None): + def release_batch(result): try: yield result finally: if result is not None: - for req in result.all_requests(): + for req in itertools.chain(result.generation_requests, + result.context_requests): kv_cache_manager.free_resources(req) if spec_resource_manager is not None: spec_resource_manager.free_resources(req) @@ -1152,15 +1152,7 @@ def _prepare_tp_inputs( draft_lens = [] mrope_config = defaultdict(list) - mtp_batch_idx = 0 # Temporary: MTP (and Eagle3OneModel) remain the only samplers to index new_tokens serially - - def py_batch_idx(request: LlmRequest) -> int: - if not self.without_logits: - return request.seq_slot - nonlocal mtp_batch_idx - batch_idx = mtp_batch_idx - mtp_batch_idx += 1 - return batch_idx + batch_idx = 0 for request in scheduled_requests.context_requests: request_ids.append(request.py_request_id) @@ -1191,9 +1183,10 @@ def py_batch_idx(request: LlmRequest) -> int: ) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin mrope_config['mrope_rotary_cos_sin'].append( mrope_rotary_cos_sin.to('cuda', non_blocking=True)) - request.py_batch_idx = py_batch_idx(request) + request.py_batch_idx = batch_idx + batch_idx += 1 - num_ctx_requests = len(scheduled_requests.context_requests) + num_ctx_requests = batch_idx num_ctx_tokens = len(input_ids) new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None if new_tensors_device is not None: @@ -1233,7 +1226,7 @@ def py_batch_idx(request: LlmRequest) -> int: assert spec_dec_mode.support_overlap_scheduler( ), f"{self.spec_config.spec_dec_name} does not support overlap scheduler" - # will contain previous batch indices of generation requests + # will contain previous batch incices of generation requests previous_batch_indices = [] previous_pos_indices = [] for request in extend_requests: @@ -1273,11 +1266,13 @@ def py_batch_idx(request: LlmRequest) -> int: num_cached_tokens_per_seq.append(past_seen_token_num) request_ids.append(request.py_request_id) # update batch index - request.py_batch_idx = py_batch_idx(request) + request.py_batch_idx = batch_idx + batch_idx += 1 else: # update batch index previous_batch_idx = request.py_batch_idx - request.py_batch_idx = py_batch_idx(request) + request.py_batch_idx = batch_idx + batch_idx += 1 # inputs # overlap scheduler can only support the speculative decoding # methods with a fixed number of draft tokens @@ -1328,21 +1323,12 @@ def py_batch_idx(request: LlmRequest) -> int: 
prompt_lengths.append(request.py_prompt_len) draft_lens.append(0) - request.py_batch_idx = py_batch_idx(request) - - previous_batch_len = len(previous_batch_indices) - - def previous_seq_slots_device(): - previous_batch_indices_host = torch.tensor(previous_batch_indices, - dtype=torch.int, - pin_memory=True) - previous_slots = self.previous_batch_indices_cuda[: - previous_batch_len] - previous_slots.copy_(previous_batch_indices_host, non_blocking=True) - return previous_slots + request.py_batch_idx = batch_idx + batch_idx += 1 num_tokens = len(input_ids) num_draft_tokens = len(draft_tokens) + previous_batchs = len(previous_batch_indices) num_requests = len(request_ids) total_num_tokens = len(position_ids) assert total_num_tokens <= self.max_num_tokens, ( @@ -1360,55 +1346,67 @@ def previous_seq_slots_device(): self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens, non_blocking=True) if next_draft_tokens_device is not None: - if previous_batch_len > 0: - previous_slots = previous_seq_slots_device() + if len(previous_batch_indices) > 0: + previous_batch_indices = torch.tensor(previous_batch_indices, + dtype=torch.int, + pin_memory=True) + self.previous_batch_indices_cuda[:previous_batchs].copy_( + previous_batch_indices, non_blocking=True) # previous input ids - previous_batch_tokens = previous_batch_len * ( - 1 + self.max_draft_len) - new_tokens = new_tokens_device[previous_slots, :].flatten() - self.input_ids_cuda[num_tokens:num_tokens + - previous_batch_tokens].copy_( - new_tokens, non_blocking=True) + previous_batch_tokens = previous_batchs * (1 + + self.max_draft_len) + self.input_ids_cuda[ + num_tokens:num_tokens + + previous_batch_tokens].copy_(new_tokens_device[ + self.previous_batch_indices_cuda[:previous_batchs], :]. + flatten(), + non_blocking=True) # previous draft tokens - previous_batch_draft_tokens = previous_batch_len * self.max_draft_len - self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens + - previous_batch_draft_tokens].copy_( - next_draft_tokens_device[ - previous_slots, :].flatten(), - non_blocking=True) + previous_batch_draft_tokens = previous_batchs * self.max_draft_len + self.draft_tokens_cuda[ + num_draft_tokens:num_draft_tokens + + previous_batch_draft_tokens].copy_(next_draft_tokens_device[ + self.previous_batch_indices_cuda[:previous_batchs], :]. 
+ flatten(), + non_blocking=True) # prepare data for the preprocess inputs kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1 - previous_pos_indices_host = torch.tensor(previous_pos_indices, - dtype=torch.int, - pin_memory=True) + previous_pos_indices = torch.tensor(previous_pos_indices, + dtype=torch.int, + pin_memory=True) self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_( - previous_pos_indices_host, non_blocking=True) + previous_pos_indices, non_blocking=True) self.previous_pos_id_offsets_cuda[ 0:previous_batch_tokens].copy_( new_tokens_lens_device[self.previous_pos_indices_cuda[ 0:previous_batch_tokens]], non_blocking=True) - self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_( - kv_len_offsets_device[previous_slots], non_blocking=True) + self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_( + kv_len_offsets_device[ + self.previous_batch_indices_cuda[:previous_batchs]], + non_blocking=True) # for the requests that do not have previous batch, set the previous_pos_id_offsets and # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs self.previous_pos_id_offsets_cuda[ previous_batch_tokens:num_requests * (1 + self.max_draft_len)] *= 0 self.previous_kv_lens_offsets_cuda[ - previous_batch_len:num_requests] *= 0 + previous_batchs:num_requests] *= 0 else: # change the data to zeros to skip the value changes in _preprocess_inputs self.previous_pos_id_offsets_cuda *= 0 self.previous_kv_lens_offsets_cuda *= 0 elif new_tokens_device is not None: - seq_slots_device = previous_seq_slots_device() - max_draft_len = max(draft_lens) - new_tokens = new_tokens_device[:max_draft_len + 1, - seq_slots_device, :self.BEAM_WIDTH] - self.input_ids_cuda[num_tokens:num_tokens + - previous_batch_len].copy_(new_tokens.flatten(), - non_blocking=True) + previous_batch_tokens = len(previous_batch_indices) + previous_batch_indices = torch.tensor(previous_batch_indices, + dtype=torch.int, + pin_memory=True) + self.previous_batch_indices_cuda[:previous_batch_tokens].copy_( + previous_batch_indices, non_blocking=True) + self.input_ids_cuda[num_tokens:num_tokens + previous_batchs].copy_( + new_tokens_device[ + self.previous_batch_indices_cuda[:previous_batchs]], + non_blocking=True) position_ids = torch.tensor(position_ids, dtype=torch.int, @@ -1646,6 +1644,7 @@ def _prepare_star_attention_inputs(self, # for star attention, we need customized block ids block_ids_per_seq = [] num_cached_tokens_per_seq = [] + output_token_idx = 0 for request in scheduled_requests.context_requests: request_ids.append(request.py_request_id) prompt_lengths.append(request.py_prompt_len) @@ -1702,6 +1701,8 @@ def _prepare_star_attention_inputs(self, sequence_lengths.append(len(input_id)) block_ids_per_seq.extend([all_cache_indices]) num_cached_tokens_per_seq.append(past_seen_token_num) + request.output_token_idx = output_token_idx + output_token_idx += 1 num_contexts = len(sequence_lengths) for request in scheduled_requests.context_requests: ctx_iter = request.ctx_iters @@ -1741,6 +1742,8 @@ def _prepare_star_attention_inputs(self, sequence_lengths.append(len(input_id)) block_ids_per_seq.extend([all_cache_indices]) num_cached_tokens_per_seq.append(past_seen_token_num) + request.output_token_idx = output_token_idx + output_token_idx += 1 num_queries = len(sequence_lengths) - num_contexts # Requests with draft tokens are treated like extend requests. 
@@ -1798,6 +1801,8 @@ def _prepare_star_attention_inputs(self, position_ids.append(last_query_pos_id + request.gen_iters + 1) block_ids_per_seq.extend([all_cache_indices]) num_cached_tokens_per_seq.append(past_seen_token_num) + request.output_token_idx = output_token_idx + output_token_idx += 1 num_tokens = len(input_ids) assert num_tokens <= self.max_num_tokens, ( @@ -2165,7 +2170,9 @@ def _execute_logit_post_processors(self, num_ctx_req = len(scheduled_requests.context_requests) logits_tensor = outputs["logits"] - for idx, request in enumerate(scheduled_requests.all_requests()): + for idx, request in enumerate( + itertools.chain(scheduled_requests.context_requests, + scheduled_requests.generation_requests)): logits_processors = getattr(request, "py_logits_post_processors", None) if not logits_processors: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index edab51fb7f5..a08680832c7 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -11,12 +11,12 @@ import weakref from collections import namedtuple from contextlib import contextmanager +from itertools import chain from typing import Dict, List, Optional, Tuple, Union import torch from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType -from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank, is_trace_enabled, nvtx_range, trace_func) from tensorrt_llm.bindings.executor import (DisServingRequestStats, @@ -35,7 +35,7 @@ LlmResponse, executor_request_to_llm_request) from .model_engine import ModelEngine from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler -from .scheduler import RequestScheduler, ScheduledRequests +from .scheduler import ScheduledRequests # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." @@ -162,11 +162,10 @@ class PyExecutor: def __init__(self, resource_manager, - scheduler: RequestScheduler, + scheduler, model_engine: ModelEngine, sampler: Sampler, dist: Distributed, - max_num_sequences: int, disable_overlap_scheduler: bool = False, max_input_len: int = 2048, max_batch_size: int = 8, @@ -267,13 +266,11 @@ def __init__(self, if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) - if self.draft_model_engine is not None: - if self.event_loop.__name__ != self._executor_loop.__name__: - raise NotImplementedError( - "Drafting is not supported for selected executor loop. " - "Please disable disagg/pipeline parallelism/overlap scheduler." - ) - self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences) + if self.draft_model_engine is not None and self.event_loop.__name__ != self._executor_loop.__name__: + raise NotImplementedError( + "Drafting is not supported for selected executor loop. 
" + "Please disable disagg/pipeline parallelism/overlap scheduler.") + self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold self.worker_started = False @@ -755,7 +752,7 @@ def _executor_loop_pp(self): "cpu", non_blocking=True) sample_state = self._sample_async( scheduled_batch, batch_outputs) - sample_state.host.logits = logits_host + sample_state.logits_host = logits_host self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: @@ -786,6 +783,7 @@ def _executor_loop_pp(self): # Receive tokens from previous pp rank (w.r.t model forward direction) ( logits, + sample_state.log_probs, sample_state.host, ) = self.dist.recv_object( src=self.dist.prev_pp_rank, @@ -793,9 +791,8 @@ def _executor_loop_pp(self): ) if logits is not None: logits_host = torch.from_numpy(logits) - sample_state.host.logits = logits_host - sample_state.device.logits = logits_host.to( - self.device_id) + sample_state.logits_host = logits_host + sample_state.logits = logits_host.to(self.device_id) else: torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp") sample_state.sampler_event.synchronize() @@ -805,16 +802,16 @@ def _executor_loop_pp(self): if not self.dist.is_second_last_pp_rank: if self.send_handles[prev_microbatch_id] is not None: self.send_handles[prev_microbatch_id].Wait() - needs_logits = ( - self._need_return_logits(scheduled_batch) - or (self._need_return_log_probs(scheduled_batch) - and sample_state.host.log_probs is not None)) - serialized_logits = sample_state.host.logits.numpy( - ) if needs_logits else None self.send_handles[ prev_microbatch_id] = self.dist.isend_object( ( - serialized_logits, + sample_state.logits_host.numpy() if + self._need_return_logits(scheduled_batch) or + (self._need_return_log_probs( + scheduled_batch) + and sample_state.log_probs is not None) + else None, + sample_state.log_probs, sample_state.host, ), dest=self.dist.next_pp_rank, @@ -1718,7 +1715,8 @@ def _insert_ngram_iter_stats( total_num_draft_tokens = 0 total_num_accepted_tokens = 0 num_requests_with_draft_tokens = 0 - for request in scheduled_requests.all_requests(): + for request in chain(scheduled_requests.context_requests, + scheduled_requests.generation_requests): num_draft_tokens = 0 if request.py_last_draft_tokens is None else len( request.py_last_draft_tokens) num_accepted_tokens = getattr(request, @@ -1789,33 +1787,38 @@ def _prepare_draft_batch( input_tokens = spec_config.get_draft_model_prompt( request.get_tokens()[beam_idx]) - def create_new_request(input_tokens): - return LlmRequest(request_id=request.py_request_id, - max_new_tokens=request.py_max_new_tokens, - input_tokens=input_tokens, - sampling_config=request.sampling_config, - is_streaming=False, - is_draft=True) - if request.max_beam_num_tokens - 1 == request.py_prompt_len: # This is the first time the draft model is seeing this request. # Prepare a context request. We discard the first token and take # the newly decoded one - this is the convention for EAGLE 2 and 3. 
assert num_draft_tokens == 0 - new_request = create_new_request(input_tokens) + new_request = LlmRequest( + request_id=request.py_request_id, + max_new_tokens=request.py_max_new_tokens, + input_tokens=input_tokens, + sampling_config=request.sampling_config, + is_streaming=False) + draft_batch.context_requests.append(new_request) elif num_accepted_tokens == 0: - new_request = create_new_request(input_tokens[:-1]) + new_request = LlmRequest( + request_id=request.py_request_id, + max_new_tokens=request.py_max_new_tokens, + input_tokens=input_tokens[:-1], + sampling_config=request.sampling_config, + is_streaming=False) # Explicitly add the last token so get_last_tokens() returns # the right value new_request.add_new_token(input_tokens[-1], beam_idx) new_request.state = LlmRequestState.GENERATION_IN_PROGRESS draft_batch.generation_requests.append(new_request) else: - new_request = create_new_request(input_tokens) - new_request.context_chunk_size = num_accepted_tokens + 1 - new_request.context_current_position = len( - input_tokens) - num_accepted_tokens - 1 + new_request = LlmRequest( + request_id=request.py_request_id, + max_new_tokens=request.py_max_new_tokens, + input_tokens=input_tokens, + sampling_config=request.sampling_config, + is_streaming=False) new_request.context_chunk_size = num_accepted_tokens + 1 new_request.context_current_position = len( input_tokens) - num_accepted_tokens - 1 @@ -1834,19 +1837,16 @@ def create_new_request(input_tokens): @nvtx_range("_prepare_draft_tokens") def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests): - if not self.draft_model_engine: - raise ValueError("Draft model engine is not set") - try: draft_batch = self._prepare_draft_batch(scheduled_requests) if draft_batch.batch_size == 0: return - self.draft_seq_slot_manager.prepare_resources(draft_batch) req_id_to_old_request = { req.py_request_id: req - for req in scheduled_requests.all_requests() + for req in chain(scheduled_requests.context_requests, + scheduled_requests.generation_requests) } # Disable cuda graph for the 1st draft model forward @@ -1868,7 +1868,8 @@ def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests): def _process_decoded_tokens(draft_batch): new_requests = [] - for req in draft_batch.all_requests(): + for req in chain(draft_batch.context_requests, + draft_batch.generation_requests): target_model_req = req_id_to_old_request[req.py_request_id] target_model_req.py_draft_tokens.append( req.get_last_tokens(0)) @@ -1876,8 +1877,6 @@ def _process_decoded_tokens(draft_batch): target_model_req.py_draft_tokens ) < target_model_req.py_draft_pages_allocated: new_requests.append(req) - else: - self.draft_seq_slot_manager.free_resources(req) return new_requests @@ -2115,12 +2114,14 @@ def _pause_requests(self, requests_to_pause): def _add_inflight_ids(self, scheduled_requests): """Add reqids of current requests to self.inflight_req_ids.""" - for req in scheduled_requests.all_requests(): + for req in chain(scheduled_requests.context_requests, + scheduled_requests.generation_requests): self.inflight_req_ids.insert(req.request_id) def _remove_inflight_ids(self, scheduled_requests): """Remove reqids of current requests from self.inflight_req_ids.""" - for req in scheduled_requests.all_requests(): + for req in chain(scheduled_requests.context_requests, + scheduled_requests.generation_requests): self.inflight_req_ids.erase(req.request_id) def _should_exclude_last_generation_logits(self) -> bool: diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py 
b/tensorrt_llm/_torch/pyexecutor/sampler.py index 04a07dc502b..2c116831964 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -27,11 +27,9 @@ from .scheduler import ScheduledRequests -@dataclass(kw_only=True) +@dataclass(frozen=True, kw_only=True) class SampleStateTensors: new_tokens: torch.Tensor - logits: torch.Tensor | None = None - log_probs: torch.Tensor | None = None def values(self): return vars(self).values() @@ -41,6 +39,13 @@ def values(self): class SampleState: scheduled_requests: ScheduledRequests + logits: torch.Tensor = None + logits_host: torch.Tensor = None + + # Set when decode_async() has evaluated these to avoid computing again in update_requests() + # log_probs[request_idx][token_idx] + log_probs: list[list[float] | None] | None = None + device: SampleStateTensors = None host: SampleStateTensors = None @@ -72,12 +77,10 @@ class EarlyStopSampler(Sampler): def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs) -> SampleState: - host = SampleStateTensors(logits=model_outputs['logits'], - new_tokens=torch.empty(0)) - return SampleState(scheduled_requests=scheduled_requests, host=host) + return SampleState(scheduled_requests=scheduled_requests, + logits=model_outputs['logits']) def update_requests(self, state: SampleState) -> None: - assert isinstance(state, SampleState) scheduled_requests = state.scheduled_requests assert (not scheduled_requests.generation_requests) for idx, request in enumerate(scheduled_requests.context_requests): @@ -85,7 +88,7 @@ def update_requests(self, state: SampleState) -> None: # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) if request.py_return_context_logits: - logits = state.host.logits[idx] + logits = state.logits[idx] if logits.ndim == 1: # For BERT: Add axis to be compatible with LogitsStorage # (LogitsStorage will interpret this dim as the prompt_len which @@ -101,6 +104,8 @@ def top_k_sampling_batch(logits, top_k=50): # logits should be 2D :[batch_size, vocab_size] batch_size, vocab_size = logits.size() + raw_probs = torch.softmax(logits, dim=-1) + # get first top_k logits of each sample and their indices values, indices = torch.topk(logits, top_k, dim=-1) min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) @@ -110,18 +115,24 @@ def top_k_sampling_batch(logits, top_k=50): torch.full_like(logits, float('-inf')), logits) # compute probability distribution - softmax = torch.softmax(logits, dim=-1) + probs = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1) - return next_tokens, softmax + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) + token_probs = torch.gather(raw_probs, dim=1, + index=next_tokens.unsqueeze(1)).squeeze(-1) + log_probs = torch.log(token_probs) + return next_tokens, log_probs -def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9): +def top_p_sampling_batch(logits, top_p=0.9): logits_dim = logits.dim() if logits_dim == 1: logits = logits.unsqueeze(0) - assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" + # logits should be 2D :[batch_size, vocab_size] + batch_size, vocab_size = logits.size() + + raw_probs = torch.softmax(logits, dim=-1) # sort the logits of each sample in descending order sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) @@ -141,82 +152,46 @@ def 
top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9): logits = logits.masked_fill(indices_to_remove, float('-inf')) # compute probability distribution - softmax = torch.softmax(logits, dim=-1) + probs = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1) - return next_tokens, softmax + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) + token_probs = torch.gather(raw_probs, dim=1, + index=next_tokens.unsqueeze(1)).squeeze(-1) + log_probs = torch.log(token_probs) + return next_tokens, log_probs def greedy_search_sampling_batch(logits): + raw_probs = torch.softmax(logits, dim=-1) next_tokens = torch.argmax(logits, dim=-1) - softmax = torch.softmax(logits, dim=-1) - return next_tokens, softmax + token_probs = torch.gather(raw_probs, dim=1, + index=next_tokens.unsqueeze(1)).squeeze(-1) + log_probs = torch.log(token_probs) + return next_tokens, log_probs -def sample_single_request(request: LlmRequest, logits: torch.Tensor): +def decode_single_request(request: LlmRequest, logits): assert logits.dim( ) == 2 and logits.shape[0] == 1, "logits should have shape [1, vocab_size]" if request.sampling_config.top_p is not None and len( request.sampling_config.top_p) > 0: - return top_p_sampling_batch(logits, request.sampling_config.top_p[0]) + next_tokens, log_probs = top_p_sampling_batch( + logits, request.sampling_config.top_p[0]) elif request.sampling_config.top_k is not None and len( request.sampling_config.top_k) > 0: - return top_k_sampling_batch(logits, request.sampling_config.top_k[0]) + next_tokens, log_probs = top_k_sampling_batch( + logits, request.sampling_config.top_k[0]) else: - return greedy_search_sampling_batch(logits) - - -def new_tokens_slice(request: LlmRequest, beam: int, *, - size: int) -> tuple[slice, int, int]: - return slice(0, size), request.seq_slot, beam - - -def add_token(request: LlmRequest, - new_tokens: torch.Tensor, - *, - beam: int, - step: int = 0) -> int: - seq_slot = request.seq_slot - assert seq_slot is not None - new_token = int(new_tokens[step, request.seq_slot, beam]) - request.add_new_token(new_token, beam) - return new_token + next_tokens, log_probs = greedy_search_sampling_batch(logits) + return next_tokens, log_probs class TorchSampler(Sampler): - BEAM = 0 - MAX_BEAM_WIDTH = BEAM + 1 - - @dataclass(frozen=True, kw_only=True) - class Store: - new_tokens: torch.Tensor - """Shape: See cpp DecoderState.getAllNewTokens()""" - - @dataclass(frozen=True, kw_only=True) - class Args: - max_seq_len: int - max_draft_tokens: int - max_num_sequences: int - max_beam_width: int - mixed_sampler: bool - - def __init__(self, args: Args): - self.max_seq_len = args.max_seq_len - self.mixed_sampler = args.mixed_sampler - self.max_tokens = args.max_draft_tokens + 1 - assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1" - self.num_seq_slots = args.max_num_sequences - - # AutoDeploy build creates the sampler in inference mode, - # which would disallow in-place mutating of new_tokens. - # So, we temporarily exit inference mode. 
- with torch.inference_mode(False): - new_tokens = torch.zeros( - (self.max_tokens, self.num_seq_slots, self.MAX_BEAM_WIDTH), - dtype=torch.int, - device='cuda') - self.store = self.Store(new_tokens=new_tokens) + + def __init__(self, max_seq_len: int, mixed_sampler: bool = False): + self.max_seq_len = max_seq_len + self.mixed_sampler = mixed_sampler def _meet_max_token_stop_criteria(self, request: LlmRequest, num_tokens: int): @@ -224,8 +199,7 @@ def _meet_max_token_stop_criteria(self, request: LlmRequest, >= request.py_max_new_tokens) or (num_tokens >= self.max_seq_len) - @staticmethod - def _meet_stop_token_criteria(request: LlmRequest): + def _meet_stop_token_criteria(self, request: LlmRequest): if request.py_stop_words_list: assert isinstance( request.py_stop_words_list, @@ -243,163 +217,233 @@ def _meet_stop_token_criteria(request: LlmRequest): return True return False - def _handle_stop_criteria(self, request: LlmRequest, new_token: int, *, - beam: int) -> bool: + def _handle_stop_criteria(self, request: LlmRequest, new_token: int, + num_tokens: int, beam_idx: int) -> bool: """Handle stop criteria and set appropriate finish reasons and state. Returns True if generation should stop.""" if new_token == request.py_end_id: - request.finish_by_reason(FinishReason.END_ID) + request.state = LlmRequestState.GENERATION_COMPLETE + request.set_finished_reason(FinishReason.END_ID, beam_idx) return True - num_tokens = request.get_num_tokens(beam) if self._meet_max_token_stop_criteria(request, num_tokens): - request.finish_by_reason(FinishReason.LENGTH) + request.state = LlmRequestState.GENERATION_COMPLETE + request.set_finished_reason(FinishReason.LENGTH, beam_idx) return True if self._meet_stop_token_criteria(request): - request.finish_by_reason(FinishReason.STOP_WORDS) + request.state = LlmRequestState.GENERATION_COMPLETE + request.set_finished_reason(FinishReason.STOP_WORDS, beam_idx) return True return False - def handle_logits(self, request: LlmRequest, state: SampleState, *, - beam: int, count: int): - current_slice = new_tokens_slice(request, beam, size=count) - if request.py_return_generation_logits: - assert state.host.logits is not None - current_logits = state.host.logits[current_slice] + def update_requests(self, state: SampleState) -> None: + if state.sampler_event: + state.sampler_event.synchronize() + new_tokens_list = state.host.new_tokens.tolist() + scheduled_requests = state.scheduled_requests + + request_idx = 0 + token_idx = 0 + beam_idx = 0 + + def advance_idx(num_tokens=1): + nonlocal request_idx, token_idx + request_idx += 1 + token_idx += num_tokens + + def handle_logits(request: LlmRequest, tokens: list[int], count=1): + if state.logits is None: + return + if not request.py_return_generation_logits and not request.py_return_log_probs: + return + + current_slice = slice(token_idx, token_idx + count) + current_logits = state.logits[current_slice] + request.py_result.append_generation_logits(current_logits) - if request.py_return_log_probs: - assert state.host.log_probs is not None - log_probs = state.host.log_probs[request.seq_slot][beam][:count] - current_tokens = state.host.new_tokens[current_slice] + + if not request.py_return_log_probs: + return + + if state.log_probs: + log_probs = state.log_probs[request_idx] + else: + _, log_probs = greedy_search_sampling_batch(current_logits) token_log_probs = [{ - int(token): Logprob(logprob=logprob, rank=1) - } for token, logprob in zip(current_tokens, log_probs.tolist())] - assert beam == 0, "The following call relies on 
beam_width to be 1 - hence the list with a single element" + token: Logprob(logprob=logprob, rank=1) + } for token, logprob in zip(tokens, log_probs.tolist())] request.py_result.append_log_probs([token_log_probs]) - def process_draft_tokens(self, request: LlmRequest, - new_tokens: torch.Tensor, new_token: int) -> int: - num_accepted = 0 - for draft_token in request.py_draft_tokens: - if draft_token != new_token: - # Reject. - break - num_accepted += 1 - new_token = add_token(request, - new_tokens, - beam=self.BEAM, - step=num_accepted) - if self._handle_stop_criteria(request, new_token, beam=self.BEAM): - break - return num_accepted - - def update_requests(self, state: SampleState) -> None: - assert isinstance(state, SampleState) - if state.sampler_event: - state.sampler_event.synchronize() - new_tokens = state.host.new_tokens + if hasattr(scheduled_requests, 'chunked_requests'): + request_idx += len(scheduled_requests.chunked_requests) + token_idx += len(scheduled_requests.chunked_requests) - for req in state.scheduled_requests.context_requests: - if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0: + for request in scheduled_requests.context_requests: + if request.context_remaining_length != 0: + advance_idx() continue - new_token = add_token(req, new_tokens, beam=self.BEAM) - stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM) - self.handle_logits(req, state, beam=self.BEAM, count=1) - req.py_decoding_iter += 1 - for req in state.scheduled_requests.generation_requests: - if req.state == LlmRequestState.GENERATION_COMPLETE: + if request.state != LlmRequestState.GENERATION_COMPLETE: + new_token = new_tokens_list[token_idx] + num_tokens = request.add_new_token(new_token, beam_idx) + self._handle_stop_criteria(request, new_token, num_tokens, + beam_idx) + handle_logits(request, [new_token]) + request.py_decoding_iter += 1 + advance_idx() + + extend_requests = [] + generation_requests = [] + for request in scheduled_requests.generation_requests: + if len(request.py_draft_tokens) > 0: + extend_requests.append(request) + else: + generation_requests.append(request) + + for request in extend_requests: + if request.state != LlmRequestState.GENERATION_COMPLETE: + new_token = new_tokens_list[token_idx] + num_tokens = request.add_new_token(new_token, beam_idx) + if self._handle_stop_criteria(request, new_token, num_tokens, + beam_idx): + continue + + # Accept draft tokens (if we have any) if and only if they match the new + # token exactly. + num_accepted = 0 + new_tokens = [new_token] + for draft_token in request.py_draft_tokens: + if draft_token != new_token: + # Reject. 
+ break + num_accepted += 1 + new_token = new_tokens_list[token_idx + num_accepted] + num_tokens = request.add_new_token(new_token, beam_idx) + new_tokens.append(num_tokens) # `num_tokens`->`new_token` + + if self._handle_stop_criteria(request, new_token, + num_tokens, beam_idx): + break + handle_logits(request, new_tokens, num_accepted) + request.py_decoding_iter += 1 + request.py_num_accepted_draft_tokens = num_accepted + request.py_rewind_len = request.py_draft_pages_allocated - num_accepted + advance_idx(len(request.py_draft_tokens) + 1) + + for request in generation_requests: + if request.state != LlmRequestState.GENERATION_COMPLETE: + new_token = new_tokens_list[token_idx] + num_tokens = request.add_new_token(new_token, beam_idx) + self._handle_stop_criteria(request, new_token, num_tokens, + beam_idx) + handle_logits(request, [new_token]) + request.py_decoding_iter += 1 + advance_idx() + + def _mixed_sample(self, scheduled_requests: ScheduledRequests, + model_outputs) -> SampleState: + logits = model_outputs["logits"] + log_probs = [] + new_tokens_device_array = [] + + idx = 0 + + for request in scheduled_requests.context_requests: + assert not request.py_return_context_logits, "Return context logits not supported" + token_logits = logits[idx:idx + 1, :] + new_token, probs = decode_single_request(request, token_logits) + new_tokens_device_array.append(new_token) + probs = [probs.tolist()] if request.py_return_log_probs else None + log_probs.append(probs) # Currently always beam_width=1 + idx += 1 + + for request in scheduled_requests.generation_requests: + if request.state == LlmRequestState.GENERATION_COMPLETE: continue - new_token = add_token(req, new_tokens, beam=self.BEAM) - stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM) - processed = 1 - if not stop and len(req.py_draft_tokens) > 0: - num_accepted = self.process_draft_tokens( - req, new_tokens, new_token) - req.py_num_accepted_draft_tokens = num_accepted - req.py_rewind_len = req.py_draft_pages_allocated - num_accepted - processed += num_accepted - self.handle_logits(req, state, beam=self.BEAM, count=processed) - req.py_decoding_iter += 1 - - def log_probs_host(self, requests: Iterable[LlmRequest]): - """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103""" - if any(req.py_return_log_probs for req in requests): - return torch.empty( - (self.num_seq_slots, self.MAX_BEAM_WIDTH, self.max_tokens), - device="cpu", - pin_memory=True) - return None - - def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int): - if any(req.py_return_generation_logits for req in requests): - return torch.empty((self.max_tokens, self.num_seq_slots, - self.MAX_BEAM_WIDTH, vocab_size), - device="cpu", - pin_memory=True) - return None + assert len( + request.py_draft_tokens + ) == 0, "Speculative decoding not supported in SeparateDecoder." 
+ token_logits = logits[idx:idx + 1, :] + new_token, probs = decode_single_request(request, token_logits) + new_tokens_device_array.append(new_token) + probs = [probs.tolist()] if request.py_return_log_probs else None + log_probs.append(probs) # Currently always beam_width=1 + idx += 1 + + new_tokens_device = torch.cat(new_tokens_device_array) + new_tokens_host = new_tokens_device.to('cpu', non_blocking=True) + sampler_event = torch.cuda.Event() + sampler_event.record() - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs: dict[str, torch.Tensor]) -> SampleState: - requests = scheduled_requests.all_requests() - new_tokens = self.store.new_tokens - vocab_size = model_outputs["logits"].shape[-1] - log_probs_host = self.log_probs_host(requests) - gen_logits_host = self.gen_logits_host(requests, vocab_size) - self._process_requests(requests, - model_outputs, - new_tokens, - gen_logits_host=gen_logits_host, - log_probs_host=log_probs_host) - new_tokens_host = new_tokens.to(device="cpu", non_blocking=True) + return SampleState( + scheduled_requests=scheduled_requests, + logits=logits, + device=SampleStateTensors(new_tokens=new_tokens_device), + host=SampleStateTensors(new_tokens=new_tokens_host), + sampler_event=sampler_event, + log_probs=log_probs) + + def _batch_sample(self, scheduled_requests: ScheduledRequests, + model_outputs) -> SampleState: + logits = model_outputs["logits"] + new_tokens_device = torch.argmax(logits, dim=-1) + new_tokens_host = new_tokens_device.to('cpu', non_blocking=True) sampler_event = torch.cuda.Event() sampler_event.record() - return SampleState(scheduled_requests=scheduled_requests, - device=SampleStateTensors(new_tokens=new_tokens), - host=SampleStateTensors(new_tokens=new_tokens_host, - log_probs=log_probs_host, - logits=gen_logits_host), - sampler_event=sampler_event) - - def _process_requests(self, - requests: list[LlmRequest], - model_outputs: dict[str, torch.Tensor], - new_tokens: torch.Tensor, - *, - gen_logits_host: torch.Tensor | None = None, - log_probs_host: torch.Tensor | None = None): - beam = self.BEAM - offset = 0 - raw_logits = model_outputs["logits"] - - for request in requests: - steps = 1 - if len(request.py_draft_tokens) > 0: - assert not self.mixed_sampler, "Speculative decoding not supported in mixed sampler" - steps += len(request.py_draft_tokens) - logits = raw_logits[offset:offset + steps] - if self.mixed_sampler: - next_tokens, softmax = sample_single_request(request, logits) - else: - next_tokens, softmax = greedy_search_sampling_batch(logits) - current_slice = new_tokens_slice(request, beam, size=steps) - new_tokens[current_slice] = next_tokens - if "d2t" in model_outputs: # Eagle3 - new_tokens[current_slice] += model_outputs["d2t"][ - new_tokens[current_slice]] - if gen_logits_host is not None: - gen_logits_host[current_slice].copy_(logits, non_blocking=True) - if log_probs_host is not None: - assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze" - token_probs = torch.gather( - softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1) - log_probs = torch.log(token_probs) - log_probs_host[request.seq_slot, - beam, :steps].copy_(log_probs, non_blocking=True) - offset += steps + return SampleState( + scheduled_requests=scheduled_requests, + logits=logits, + device=SampleStateTensors(new_tokens=new_tokens_device), + host=SampleStateTensors(new_tokens=new_tokens_host), + sampler_event=sampler_event) + + def sample_async(self, scheduled_requests: ScheduledRequests, + 
model_outputs) -> SampleState: + if self.mixed_sampler: + return self._mixed_sample(scheduled_requests, model_outputs) + else: + return self._batch_sample(scheduled_requests, model_outputs) + + +class TorchStarAttentionSampler(TorchSampler): + + def update_one_request(self, request: LlmRequest, + new_tokens_list: list[int], logits: torch.Tensor): + beam_idx = 0 + + output_token_idx = request.output_token_idx + new_token = new_tokens_list[output_token_idx] + num_tokens = request.add_new_token(new_token, beam_idx) + + current_logits = logits[output_token_idx].unsqueeze(0) + if request.py_return_generation_logits: + request.py_result.append_generation_logits(current_logits) + if request.py_return_log_probs: + _, log_probs = greedy_search_sampling_batch(current_logits) + request.py_result.append_log_probs([[{ + new_token: + Logprob(logprob=log_probs.item(), rank=1) + }]]) + + self._handle_stop_criteria(request, new_token, num_tokens, beam_idx) + if request.state != LlmRequestState.GENERATION_COMPLETE: + request.py_decoding_iter += 1 + + def update_requests(self, state: SampleState): + if state.sampler_event: + state.sampler_event.synchronize() + new_tokens_list = state.host.new_tokens.tolist() + logits = state.logits + + for request in state.scheduled_requests.context_requests: + if request.state == LlmRequestState.GENERATION_IN_PROGRESS: + self.update_one_request(request, new_tokens_list, logits) + + for request in state.scheduled_requests.generation_requests: + self.update_one_request(request, new_tokens_list, logits) class Algorithms: @@ -412,17 +456,19 @@ def __repr__(self): return f"Algs({', '.join(algs)})" -@dataclass(kw_only=True) +@dataclass(frozen=True, kw_only=True) class SampleStateTensorsHostTRTLLM(SampleStateTensors): finished_sum: torch.Tensor finish_reasons: torch.Tensor sequence_lengths: torch.Tensor - cum_log_probs: torch.Tensor | None = None + log_probs: torch.Tensor + cum_log_probs: torch.Tensor @dataclass(kw_only=True) class SampleStateTRTLLM(SampleState): host: SampleStateTensorsHostTRTLLM + device: SampleStateTensors class TRTLLMSampler(Sampler): @@ -485,6 +531,13 @@ def _initialize_store(self): DecoderInputBuffers(self.max_num_sequences, self.executor_config.max_batch_size, self.MAX_DECODING_TOKENS, buffer_manager), + "new_tokens_device_tensor": + torch.empty(( + self.executor_config.max_batch_size, + self.executor_config.max_beam_width, + ), + dtype=torch.int, + device='cuda'), "sequence_lengths_host": torch.empty(( self.executor_config.max_batch_size, @@ -551,6 +604,7 @@ def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs) -> SampleStateTRTLLM: batch_size = scheduled_requests.batch_size + beam_width = self.beam_width(scheduled_requests.all_requests) self.setup_sampler_step(scheduled_requests.context_requests) @@ -580,6 +634,20 @@ def sample_async(self, scheduled_requests: ScheduledRequests, self.algs.decoder.forward_async(self.store["decoder_state"], decoding_input) + # NOTE: The following code prepares a new_tokens_device_tensor in accordance with the + # current implementation of model_engine. 
+ # TODO: When we support speculative decoding: + # new_tokens_device_tensor should be, for speculative decoding cases: [batch, 1 + draft_len], others: [batch] + new_tokens_device_tensor = self.store[ + "new_tokens_device_tensor"][:batch_size, :beam_width] + seq_slots = [ + request.seq_slot for request in scheduled_requests.all_requests + ] + new_tokens_device_tensor.copy_( + self.store["decoder_state"].all_new_tokens[0][seq_slots], + non_blocking=True) + new_tokens_device_tensor = new_tokens_device_tensor.view(-1) + new_output_tokens = self.store["decoder_state"].all_new_tokens.to( 'cpu', non_blocking=True) finished_sum = self.store["decoder_state"].finished_sum.to( @@ -589,17 +657,16 @@ def sample_async(self, scheduled_requests: ScheduledRequests, sequence_lengths = self.store["decoder_state"].sequence_lengths.to( 'cpu', non_blocking=True) - log_probs = None - cum_log_probs = None + log_probs = torch.empty([0], dtype=torch.float, device='cpu') + cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu') if any(request.py_return_log_probs - for request in scheduled_requests.all_requests()): + for request in scheduled_requests.all_requests): log_probs = self.store["decoder_state"].log_probs.to( 'cpu', non_blocking=True) cum_log_probs = self.store["decoder_state"].cum_log_probs.to( 'cpu', non_blocking=True) - device = SampleStateTensors( - new_tokens=self.store["decoder_state"].all_new_tokens) + device = SampleStateTensors(new_tokens=new_tokens_device_tensor) host = SampleStateTensorsHostTRTLLM(new_tokens=new_output_tokens, finished_sum=finished_sum, @@ -612,6 +679,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests, sampler_event.record() return SampleStateTRTLLM(scheduled_requests=scheduled_requests, + logits=model_outputs["logits"], device=device, host=host, sampler_event=sampler_event) @@ -621,8 +689,7 @@ def update_requests(self, state: SampleStateTRTLLM): scheduled_requests = state.scheduled_requests assert scheduled_requests.batch_size > 0 - requests = scheduled_requests.all_requests() - beam_width = self.beam_width(requests) + beam_width = self.beam_width(scheduled_requests.all_requests) sampler_event = state.sampler_event if sampler_event: @@ -633,7 +700,7 @@ def update_requests(self, state: SampleStateTRTLLM): finish_reasons_host = state.host.finish_reasons sequence_lengths_host_data = state.host.sequence_lengths - for request in requests: + for request in scheduled_requests.all_requests: if request.is_context_init_state: continue @@ -653,20 +720,17 @@ def update_requests(self, state: SampleStateTRTLLM): seq_len - request.get_num_tokens(beam)) for step in range(num_new_tokens[beam]): - new_token = add_token(request, - new_tokens_host, - beam=beam, - step=step) + new_token = new_tokens_host[step][seq_slot][beam] + request.add_new_token(new_token, beam) if request.py_return_log_probs: - assert state.host.log_probs is not None # NOTE: Log probs with drafting has not been tested yet. 
begin_log_probs_offset = request.prompt_len if request.sampling_config.beam_width == 1 else 0 current_token = seq_len - request.prompt_len - num_new_tokens[ beam] + step log_probs.append({ - new_token: + new_token.item(): Logprob(logprob=state.host.log_probs[seq_slot][beam] [begin_log_probs_offset + current_token].item(), diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 26df44874a0..9ce25061427 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections import namedtuple +from itertools import chain from typing import Optional from tensorrt_llm.bindings import executor as tb_executor @@ -35,8 +36,9 @@ def can_run_cuda_graph(self) -> bool: def batch_size(self) -> int: return len(self.context_requests) + len(self.generation_requests) - def all_requests(self) -> list[LlmRequest]: - return self.context_requests + self.generation_requests + @property + def all_requests(self) -> chain[LlmRequest]: + return chain(self.context_requests, self.generation_requests) class RequestScheduler(ABC): diff --git a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py index 2dfe1737467..523a9693326 100644 --- a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py @@ -1,3 +1,5 @@ +import itertools + from .llm_request import LlmRequest from .resource_manager import BaseResourceManager, SlotManager from .scheduler import ScheduledRequests @@ -15,8 +17,10 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int: return 1 def prepare_resources(self, scheduled_batch: ScheduledRequests) -> None: - for llm_req in scheduled_batch.all_requests(): - if llm_req.seq_slot is None or llm_req.is_disagg_generation_transmission_complete: + for llm_req in itertools.chain(scheduled_batch.context_requests, + scheduled_batch.generation_requests): + if (llm_req.is_context_init_state and llm_req.seq_slot is None) or \ + llm_req.is_disagg_generation_transmission_complete: llm_req.seq_slot = self.slot_manager.add_slot( llm_req.request_id) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 3ed84781036..e6183cc1528 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -10,7 +10,7 @@ from ..attention_backend import AttentionMetadata from ..pyexecutor.llm_request import LlmRequest from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager -from ..pyexecutor.sampler import TorchSampler +from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode from .mtp import MTPSampler @@ -214,6 +214,26 @@ def get_hidden_states(self): return hidden_states +class Eagle3Sampler(TorchSampler): + + def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState: + logits = model_outputs["logits"] + new_tokens_device = torch.argmax(logits, dim=-1) + if "d2t" in model_outputs: + d2t = model_outputs["d2t"] + new_tokens_device = d2t[new_tokens_device] + new_tokens_device + device = SampleStateTensors(new_tokens=new_tokens_device) + host = SampleStateTensors( + new_tokens=new_tokens_device.to('cpu', non_blocking=True)) + sampler_event = torch.cuda.Event() + sampler_event.record() + return 
SampleState(scheduled_requests=scheduled_requests, + logits=logits, + device=device, + host=host, + sampler_event=sampler_event) + + @dataclass class Eagle3OneModelSpecMetadata(SpecMetadata): # The hidden states @@ -279,10 +299,31 @@ def maybe_capture_hidden_states( break -class Eagle3OneModelSampler(MTPSampler): +class Eagle3Decoder(TorchSampler): + + def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState: + logits = model_outputs["logits"] + new_tokens_device = torch.argmax(logits, dim=-1) + if "d2t" in model_outputs: + d2t = model_outputs["d2t"] + new_tokens_device = d2t[new_tokens_device] + new_tokens_device + new_tokens_host = new_tokens_device.to('cpu', non_blocking=True) + new_tensors_device = {"new_tokens_device": new_tokens_device} + new_tensors_host = {"new_tokens_host": new_tokens_host} + decoder_event = torch.cuda.Event() + decoder_event.record() + return SampleState(scheduled_requests=scheduled_requests, + logits=logits, + new_tensors_device=new_tensors_device, + new_tensors_host=new_tensors_host, + decoder_event=decoder_event) + + +class Eagle3OneModelDecoder(MTPSampler): - def __init__(self, args: TorchSampler.Args): - super().__init__(args, nextn=args.max_draft_tokens) + def __init__(self, max_seq_len: int, config: Eagle3Config): + super().__init__(max_seq_len, None) + self.draft_len = config.max_draft_tokens class Eagle3OneModelWorker(nn.Module): diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index f5e432690ee..25edbdae363 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -14,7 +14,7 @@ from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode -@dataclass(kw_only=True) +@dataclass(frozen=True, kw_only=True) class SampleStateTensorsMTP(SampleStateTensors): new_tokens_lens: torch.Tensor next_draft_tokens: torch.Tensor @@ -248,10 +248,12 @@ class MTPSampler(TorchSampler): SampleState = SampleStateMTP - def __init__(self, args: TorchSampler.Args, *, nextn: int): - super().__init__(args) + def __init__(self, max_seq_len: int, config: MTPConfig): + super().__init__(max_seq_len, False) self.mapping = None - self.draft_len = nextn + self.draft_len = 0 + if config is not None: + self.draft_len = config.num_nextn_predict_layers def _draft_meet_max_token_stop_criteria(self, request: LlmRequest, num_tokens: int, beam_idx: int): @@ -281,9 +283,8 @@ def update_requests(self, state: SampleStateMTP) -> None: if request.state != LlmRequestState.GENERATION_COMPLETE: new_token = new_tokens_list[idx][0] num_tokens = request.add_new_token(new_token, beam_idx) - should_stop = self._handle_stop_criteria(request, - new_token, - beam=beam_idx) + should_stop = self._handle_stop_criteria( + request, new_token, num_tokens, beam_idx) if self._draft_meet_max_token_stop_criteria( request, num_tokens, beam_idx): should_stop = True @@ -302,9 +303,8 @@ def update_requests(self, state: SampleStateMTP) -> None: for i in range(num_new_tokens): new_token = new_tokens[i] num_tokens = request.add_new_token(new_token, beam_idx) - should_stop = self._handle_stop_criteria(request, - new_token, - beam=beam_idx) + should_stop = self._handle_stop_criteria( + request, new_token, num_tokens, beam_idx) if should_stop: break if self._draft_meet_max_token_stop_criteria( @@ -344,6 +344,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests, for request in scheduled_requests.context_requests: request.py_draft_tokens = [1] * self.draft_len return 
SampleStateMTP(scheduled_requests=scheduled_requests, + logits=model_outputs['logits'], device=device, host=host, sampler_event=sampler_event) diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 85b2bf46e9c..3dd49bb108f 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -1,9 +1,6 @@ -from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler -from tensorrt_llm._torch.speculative.interface import SpecConfig - from .draft_target import DraftTargetSpecMetadata -from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata, - Eagle3OneModelWorker, Eagle3ResourceManager, +from .eagle3 import (Eagle3OneModelDecoder, Eagle3OneModelSpecMetadata, + Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3Sampler, Eagle3SpecMetadata) from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata, MTPWorker) @@ -80,17 +77,15 @@ def get_spec_resource_manager(spec_config, return None -def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig): +def get_spec_decoder(max_seq_len, spec_config): if spec_config.spec_dec_mode.is_mtp(): - return MTPSampler(sampler_args, - nextn=spec_config.num_nextn_predict_layers) - if spec_config.spec_dec_mode.is_eagle3(): - # TorchSampler handles Eagle3 gracefully, by integrating d2t into the sampling process - return TorchSampler(sampler_args) - if spec_config.spec_dec_mode.is_eagle3_one_model(): - return Eagle3OneModelSampler(sampler_args) - raise ValueError( - f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}") + return MTPSampler(max_seq_len, spec_config) + elif spec_config.spec_dec_mode.is_eagle3(): + return Eagle3Sampler(max_seq_len) + elif spec_config.spec_dec_mode.is_eagle3_one_model(): + return Eagle3OneModelDecoder(max_seq_len, spec_config) + else: + return None def get_num_spec_layers(spec_config):
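The restored sampling helpers in sampler.py (top_k_sampling_batch, top_p_sampling_batch, greedy_search_sampling_batch) share one piece of bookkeeping: the softmax is taken over the unfiltered logits first, and the log-probability of each sampled token is then read from that distribution with torch.gather. A standalone sketch of that pattern follows (not part of the patch; the helper name is hypothetical):

import torch

def token_log_probs(logits: torch.Tensor, next_tokens: torch.Tensor) -> torch.Tensor:
    # Log-probability of each sampled token under the unfiltered distribution.
    # logits: [batch_size, vocab_size]; next_tokens: [batch_size].
    raw_probs = torch.softmax(logits, dim=-1)
    token_probs = torch.gather(raw_probs, dim=1,
                               index=next_tokens.unsqueeze(1)).squeeze(-1)
    return torch.log(token_probs)

logits = torch.tensor([[2.0, 0.5, 0.1]])
greedy_tokens = torch.argmax(logits, dim=-1)   # tensor([0])
print(token_log_probs(logits, greedy_tokens))  # roughly tensor([-0.3168])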