11 | 11 | import weakref
12 | 12 | from collections import deque, namedtuple
13 | 13 | from contextlib import contextmanager
14 |    | -from typing import Dict, List, Optional, Tuple, Union
   | 14 | +from typing import Dict, List, Optional, Union
15 | 15 |
16 | 16 | import torch
17 | 17 |
30 | 30 | from tensorrt_llm.logger import logger
31 | 31 |
32 | 32 | from ..distributed import Distributed
33 |    | -from ..speculative.drafter import Drafter, create_drafter
   | 33 | +from ..speculative.drafter import Drafter
34 | 34 | from .kv_cache_transceiver import KvCacheTransceiver
35 | 35 | from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
36 | 36 |                            LlmResponse, executor_request_to_llm_request)
@@ -305,7 +305,7 @@ def __init__(self,
305 | 305 |         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
306 | 306 |             self.event_loop = trace_func(self.event_loop)
307 | 307 |
308 |     | -        if self.draft_model_engine is not None:
    | 308 | +        if self.drafter is not None:
309 | 309 |             if self.event_loop.__name__ != self._executor_loop.__name__:
310 | 310 |                 raise NotImplementedError(
311 | 311 |                     "Drafting is not supported for selected executor loop. "
@@ -918,8 +918,7 @@ def _executor_loop(self):
918 | 918 |
919 | 919 |                 self._pad_attention_dp_dummy_request()
920 | 920 |
921 |     | -                if self.draft_model_engine is not None or hasattr(
922 |     | -                        self, 'drafter') and self.drafter is not None:
    | 921 | +                if self.drafter is not None:
923 | 922 |                     self._prepare_draft_requests(self.active_requests)
924 | 923 |
925 | 924 |                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -959,23 +958,9 @@ def _executor_loop(self):
959 | 958 |                             scheduled_batch)
960 | 959 |
961 | 960 |                     self.resource_manager.prepare_resources(scheduled_batch)
962 |     | -                    if self.draft_model_engine is not None and self.drafter is None:
963 |     | -                        spec_resource_manager = self.resource_manager.get_resource_manager(
964 |     | -                            ResourceManagerType.SPEC_RESOURCE_MANAGER)
965 |     | -                        self.drafter = create_drafter(
966 |     | -                            spec_decoding_mode=self.model_engine.spec_config.
967 |     | -                            spec_dec_mode,
968 |     | -                            spec_config=self.model_engine.spec_config,
969 |     | -                            draft_model_engine=self.draft_model_engine,
970 |     | -                            max_draft_tokens=self.max_draft_tokens,
971 |     | -                            draft_seq_slot_manager=self.draft_seq_slot_manager,
972 |     | -                            sampler=self.sampler,
973 |     | -                            resource_manager=self.resource_manager,
974 |     | -                            spec_resource_manager=spec_resource_manager,
975 |     | -                        )
976 |     | -
977 | 961 |                     if self.drafter is not None:
978 |     | -                        self.drafter.prepare_draft_tokens(scheduled_batch)
    | 962 | +                        self.drafter.prepare_draft_tokens(
    | 963 | +                            scheduled_batch, self.resource_manager)
979 | 964 |
980 | 965 |                     if self.kv_cache_transceiver:
981 | 966 |                         # For generation requests which have completed KV cache transfer
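
With this change the executor no longer builds a drafter lazily inside the loop: it constructs one up front, checks self.drafter is not None, and passes the resource manager to prepare_draft_tokens at call time. The sketch below shows what the Drafter contract imported from ..speculative.drafter presumably looks like after the refactor. It is inferred only from the import and the call site in this diff, so the parameter names and the rest of the base class are assumptions.

    # Hedged sketch of the Drafter contract, inferred from this diff only;
    # the real base class in ..speculative.drafter may differ.
    from abc import ABC, abstractmethod
    from typing import Any, Optional


    class Drafter(ABC):
        """Produces draft tokens for speculative decoding before each target step."""

        @abstractmethod
        def prepare_draft_tokens(
                self,
                scheduled_requests: Any,  # ScheduledRequests in the real code
                resource_manager: Optional[Any] = None,  # ResourceManager in the real code
        ) -> None:
            """Attach draft tokens to the scheduled generation requests in place."""
            ...

Passing the resource manager per call rather than at construction keeps the drafter independent of a particular iteration's state, which is consistent with the simplified "if self.drafter is not None:" checks above.
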
@@ -1780,188 +1765,6 @@ def _update_requests(self, sample_state: SampleState):
1780 | 1765 |             logger.error(f"Encountered an error in sampling: {error_msg}")
1781 | 1766 |             self._handle_errors(error_msg)
1782 | 1767 |
1783 |      | -    @nvtx_range("_prepare_draft_batch")
1784 |      | -    def _prepare_draft_batch(
1785 |      | -            self, scheduled_requests: ScheduledRequests
1786 |      | -    ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
1787 |      | -        """
1788 |      | -        Prepares a batch for the draft model engine. Draft tokens are only produced
1789 |      | -        for generation requests.
1790 |      | -
1791 |      | -        The requests are prepared as follows:
1792 |      | -        1. The first time the draft engine sees a request, it's a context request.
1793 |      | -        2. Otherwise, if draft tokens were accepted on the last target model decoding
1794 |      | -           step, it's a chunked context request (we process all the accepted tokens together).
1795 |      | -        3. Otherwise, it's a generation request.
1796 |      | -        """
1797 |      | -        try:
1798 |      | -            draft_batch = ScheduledRequests()
1799 |      | -
1800 |      | -            for request in scheduled_requests.generation_requests:
1801 |      | -                if request.py_draft_pages_allocated == 0:
1802 |      | -                    # No space for draft tokens.
1803 |      | -                    continue
1804 |      | -
1805 |      | -                # Stop drafting when we hit the max seqlen. We still need dummy draft
1806 |      | -                # tokens attached to the requests to make sure everything works properly
1807 |      | -                # with CUDA graph. These dummy tokens are already added by
1808 |      | -                # _prepare_draft_requests to make the KV cache/scheduler aware of the fact
1809 |      | -                # that we want to do spec decoding, so no need to do anything else here.
1810 |      | -                # This makes the perf for this case suboptimal, but that's OK - this is
1811 |      | -                # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation.
1812 |      | -                if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len:
1813 |      | -                    continue
1814 |      | -
1815 |      | -                num_draft_tokens = len(
1816 |      | -                    request.py_last_draft_tokens
1817 |      | -                ) if request.py_last_draft_tokens is not None else 0
1818 |      | -                request.py_draft_tokens = []
1819 |      | -
1820 |      | -                num_accepted_tokens = request.py_num_accepted_draft_tokens
1821 |      | -                num_rejected_tokens = num_draft_tokens - num_accepted_tokens
1822 |      | -                assert num_rejected_tokens >= 0
1823 |      | -
1824 |      | -                spec_config = self.model_engine.spec_config
1825 |      | -                beam_idx = 0
1826 |      | -                input_tokens = spec_config.get_draft_model_prompt(
1827 |      | -                    request.get_tokens()[beam_idx])
1828 |      | -
1829 |      | -                def create_new_request(input_tokens):
1830 |      | -                    return LlmRequest(
1831 |      | -                        request_id=request.py_request_id,
1832 |      | -                        max_new_tokens=request.py_max_new_tokens,
1833 |      | -                        input_tokens=input_tokens,
1834 |      | -                        sampling_config=request.sampling_config,
1835 |      | -                        return_perf_metrics=request.return_perf_metrics,
1836 |      | -                        is_streaming=False,
1837 |      | -                        is_draft=True)
1838 |      | -
1839 |      | -                if request.max_beam_num_tokens - 1 == request.py_prompt_len:
1840 |      | -                    # This is the first time the draft model is seeing this request.
1841 |      | -                    # Prepare a context request. We discard the first token and take
1842 |      | -                    # the newly decoded one - this is the convention for EAGLE 2 and 3.
1843 |      | -                    new_request = create_new_request(input_tokens)
1844 |      | -                    draft_batch.context_requests.append(new_request)
1845 |      | -                elif num_accepted_tokens == 0:
1846 |      | -                    new_request = create_new_request(input_tokens[:-1])
1847 |      | -                    # Explicitly add the last token so get_last_tokens() returns
1848 |      | -                    # the right value
1849 |      | -                    new_request.add_new_token(input_tokens[-1], beam_idx)
1850 |      | -                    new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
1851 |      | -                    draft_batch.generation_requests.append(new_request)
1852 |      | -                else:
1853 |      | -                    new_request = create_new_request(input_tokens)
1854 |      | -                    new_request.context_chunk_size = num_accepted_tokens + 1
1855 |      | -                    new_request.context_current_position = len(
1856 |      | -                        input_tokens) - num_accepted_tokens - 1
1857 |      | -                    new_request.context_chunk_size = num_accepted_tokens + 1
1858 |      | -                    new_request.context_current_position = len(
1859 |      | -                        input_tokens) - num_accepted_tokens - 1
1860 |      | -
1861 |      | -                    draft_batch.context_requests.append(new_request)
1862 |      | -
1863 |      | -                new_request.py_stop_words_list = request.py_stop_words_list
1864 |      | -
1865 |      | -            return draft_batch
1866 |      | -
1867 |      | -        except Exception as e:
1868 |      | -            traceback.print_exc()
1869 |      | -            error_msg = str(e)
1870 |      | -            logger.error(f"Encountered an error in decode: {error_msg}")
1871 |      | -            self._handle_errors(error_msg)
1872 |      | -
1873 |      | -    @nvtx_range("_prepare_draft_tokens")
1874 |      | -    def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
1875 |      | -        if not self.draft_model_engine:
1876 |      | -            raise ValueError("Draft model engine is not set")
1877 |      | -
1878 |      | -        try:
1879 |      | -            draft_batch = self._prepare_draft_batch(scheduled_requests)
1880 |      | -
1881 |      | -            if draft_batch.batch_size == 0:
1882 |      | -                return
1883 |      | -            self.draft_seq_slot_manager.prepare_resources(draft_batch)
1884 |      | -
1885 |      | -            req_id_to_old_request = {
1886 |      | -                req.py_request_id: req
1887 |      | -                for req in scheduled_requests.all_requests()
1888 |      | -            }
1889 |      | -
1890 |      | -            # Disable cuda graph for the 1st draft model forward
1891 |      | -            if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
1892 |      | -            ):
1893 |      | -                with self.draft_model_engine.no_cuda_graph():
1894 |      | -                    outputs = self.draft_model_engine.forward(
1895 |      | -                        draft_batch, self.resource_manager)
1896 |      | -            else:
1897 |      | -                outputs = self.draft_model_engine.forward(
1898 |      | -                    draft_batch, self.resource_manager)
1899 |      | -            if hasattr(self.draft_model_engine.model.model, 'd2t'):
1900 |      | -                outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
1901 |      | -
1902 |      | -            sample_state = self._sample_async(draft_batch, outputs)
1903 |      | -            previous_batch = sample_state
1904 |      | -
1905 |      | -            self._update_request_states(draft_batch)
1906 |      | -
1907 |      | -            def _process_decoded_tokens(draft_batch):
1908 |      | -                new_requests = []
1909 |      | -                for req in draft_batch.all_requests():
1910 |      | -                    target_model_req = req_id_to_old_request[req.py_request_id]
1911 |      | -                    target_model_req.py_draft_tokens.append(
1912 |      | -                        req.get_last_tokens(0))
1913 |      | -                    if req.state != LlmRequestState.GENERATION_COMPLETE and len(
1914 |      | -                            target_model_req.py_draft_tokens
1915 |      | -                    ) < target_model_req.py_draft_pages_allocated:
1916 |      | -                        new_requests.append(req)
1917 |      | -                    else:
1918 |      | -                        self.draft_seq_slot_manager.free_resources(req)
1919 |      | -
1920 |      | -                return new_requests
1921 |      | -
1922 |      | -            # The TRTLLM attention kernels cannot handle generation requests with
1923 |      | -            # different seqlens. No issues with flashinfer, should we look into removing
1924 |      | -            # this? Just needs proper kernel support.
1925 |      | -            def _pad_to_max_draft_tokens():
1926 |      | -                for req in scheduled_requests.generation_requests:
1927 |      | -                    max_draft_len = self.max_draft_len
1928 |      | -                    num_draft_tokens = len(req.py_draft_tokens)
1929 |      | -                    req.py_draft_tokens.extend(
1930 |      | -                        0 for _ in range(max_draft_len - num_draft_tokens))
1931 |      | -
1932 |      | -            draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
1933 |      | -            draft_batch.context_requests = []
1934 |      | -
1935 |      | -            for i in range(self.max_draft_len - 1):
1936 |      | -                if len(draft_batch.generation_requests) == 0:
1937 |      | -                    break
1938 |      | -
1939 |      | -                outputs = self.draft_model_engine.forward(
1940 |      | -                    draft_batch,
1941 |      | -                    self.resource_manager,
1942 |      | -                    new_tensors_device=previous_batch.device)
1943 |      | -
1944 |      | -                if hasattr(self.draft_model_engine.model.model, 'd2t'):
1945 |      | -                    outputs[
1946 |      | -                        'd2t'] = self.draft_model_engine.model.model.d2t.data
1947 |      | -                sample_state = self._sample_async(draft_batch, outputs)
1948 |      | -                self._update_request_states(draft_batch)
1949 |      | -                self._update_requests(previous_batch)
1950 |      | -                new_requests = _process_decoded_tokens(
1951 |      | -                    previous_batch.scheduled_requests)
1952 |      | -                draft_batch.generation_requests = new_requests
1953 |      | -                previous_batch = sample_state
1954 |      | -            self._update_requests(previous_batch)
1955 |      | -            new_requests = _process_decoded_tokens(
1956 |      | -                previous_batch.scheduled_requests)
1957 |      | -            _pad_to_max_draft_tokens()
1958 |      | -
1959 |      | -        except Exception as e:
1960 |      | -            traceback.print_exc()
1961 |      | -            error_msg = str(e)
1962 |      | -            logger.error(f"Encountered an error in decode: {error_msg}")
1963 |      | -            self._handle_errors(error_msg)
1964 |      | -
1965 | 1768 |     def _handle_errors(self, error_msg: Optional[str] = None):
1966 | 1769 |         error_responses = {}
1967 | 1770 |         error_msg = error_msg or "error"
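
The removed _prepare_draft_batch docstring describes how the draft engine sees each target request: a plain context request the first time, a chunked context request when draft tokens were accepted on the last target decoding step, and a generation request otherwise. Below is a self-contained illustration of that three-way decision rule; the helper name and its string labels are hypothetical and only mirror the branches of the removed code.

    # Illustrative, standalone sketch of the classification performed by the
    # removed _prepare_draft_batch; the helper and its labels are hypothetical.
    def classify_draft_request(total_tokens: int, prompt_len: int,
                               num_accepted_draft_tokens: int) -> str:
        """total_tokens counts the prompt plus every token decoded so far."""
        if total_tokens - 1 == prompt_len:
            # First time the draft model sees this request.
            return "context"
        if num_accepted_draft_tokens == 0:
            # Nothing was accepted last step: a single-token generation step.
            return "generation"
        # Accepted tokens are replayed together as a chunked context request
        # covering the accepted tokens plus the newly sampled token.
        return "chunked_context"


    # One newly decoded token beyond a 10-token prompt, nothing drafted yet.
    assert classify_draft_request(11, 10, 0) == "context"

The removed _prepare_draft_tokens then ran the draft engine up to max_draft_len - 1 additional times, appending one drafted token per still-active request on each iteration, and finally padded py_draft_tokens with zeros so that all generation requests present equal-length draft sequences to the TRTLLM attention kernels. Per this diff, that responsibility presumably now lives behind Drafter.prepare_draft_tokens rather than in PyExecutor.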