Skip to content

Commit 2024ab2

Browse files
evezhierUbuntu
authored andcommitted
patch executor queue mocks
Signed-off-by: Olya Kozlova <[email protected]>
1 parent 8e7876d commit 2024ab2

File tree

4 files changed

+46
-25
lines changed

4 files changed

+46
-25
lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,23 +163,18 @@ def _can_process_attention_dp_request(
163163

164164
def _get_request_id(self):
165165
# (next_request_id + 1) % UINT64_MAX
166+
current_id = self.next_request_id
166167
self.next_request_id = (self.next_request_id + 1) & ((1 << 64) - 1)
167-
return self.next_request_id
168+
return current_id
168169

169170
def _generate_child_request_ids(
170171
self, request: ExecutorRequest) -> List[int] | None:
171172
""" Generate child request IDs if needed. """
172173
child_req_ids = None
173-
sampling_config = request.sampling_config
174-
beam_width = (sampling_config.beam_width
175-
if sampling_config.beam_width else 1)
176-
num_return_sequences = (sampling_config.num_return_sequences
177-
if sampling_config.num_return_sequences else 1)
178-
179-
# Create child requests if beam_width == 1 and num_return_sequences > 1.
180-
if beam_width == 1 and num_return_sequences > 1:
174+
num_children = self._get_num_child_requests(request)
175+
if num_children > 0:
181176
child_req_ids = []
182-
for _ in range(num_return_sequences - 1):
177+
for _ in range(num_children):
183178
child_req_id = self._get_request_id()
184179
if self.enable_iter_perf_stats:
185180
self.start_times[child_req_id] = time.time()
@@ -599,8 +594,8 @@ def _merge_requests(self, new_requests: list[RequestQueueItem]):
599594
req_item.id, req_item.request, req_item.child_req_ids,
600595
self._should_exclude_last_generation_logits())
601596
req_with_children.append(req)
602-
if req.children:
603-
req_with_children.extend(req.children)
597+
if req.child_requests:
598+
req_with_children.extend(req.child_requests)
604599
return req_with_children
605600

606601
def _merge_star_attention_requests(self,

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def __init__(
334334
return_log_probs, return_context_logits,
335335
return_generation_logits,
336336
exclude_last_generation_logits)
337-
self.children = []
337+
self.child_requests = []
338338

339339
def is_generation_only_request(self):
340340
return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY
@@ -377,14 +377,14 @@ def create_child_request(self, child_id):
377377
py_request.py_batch_idx = None
378378
py_request.py_seq_slot = None
379379

380-
py_request.children = []
380+
py_request.child_requests = []
381381

382382
assert py_request.is_child
383383
assert py_request.request_id == child.request_id
384384
assert py_request.parent_request_id == self.request_id
385385
assert py_request.sampling_config.random_seed != self.sampling_config.random_seed
386386

387-
self.children.append(py_request)
387+
self.child_requests.append(py_request)
388388

389389

390390
def convert_wordlist(word_list) -> List[List[int]]:

tests/unittest/_torch/test_best_of_n.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ def test_create_child_request(n: int):
6666
parent.request_id + parent.sampling_config.num_return_sequences):
6767
parent.create_child_request(child_id)
6868

69-
assert len(
70-
parent.children) == parent.sampling_config.num_return_sequences - 1
69+
assert len(parent.child_requests
70+
) == parent.sampling_config.num_return_sequences - 1
7171

72-
for ind, child in enumerate(parent.children):
72+
for ind, child in enumerate(parent.child_requests):
7373
assert child.request_id == ind + parent.request_id + 1
7474
assert child.py_request_id == child.request_id
7575
assert child.parent_request_id == parent.request_id
@@ -88,7 +88,7 @@ def test_create_child_request(n: int):
8888
assert child.get_tokens() == parent.get_tokens()
8989
assert child.get_tokens() is not parent.get_tokens()
9090

91-
assert child.children == []
91+
assert child.child_requests == []
9292

9393

9494
@force_ampere # Save H100 resource

tests/unittest/_torch/test_executor_request_queue.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def test_enqueue_requests(executor_queue):
7575
"""Test enqueuing multiple requests."""
7676
mock_requests = [Mock(), Mock(), Mock()]
7777

78-
with patch('time.time', return_value=1234.5):
78+
with (patch('time.time', return_value=1234.5),
79+
patch.object(executor_queue, '_generate_child_request_ids')):
7980
req_ids = executor_queue.enqueue_requests(mock_requests) # type: ignore
8081

8182
assert len(req_ids) == 3
@@ -92,7 +93,8 @@ def test_enqueue_request_single(executor_queue):
9293
"""Test enqueuing a single request."""
9394
mock_request = Mock()
9495

95-
with patch('time.time', return_value=1234.5):
96+
with (patch('time.time', return_value=1234.5),
97+
patch.object(executor_queue, '_generate_child_request_ids')):
9698
req_id = executor_queue.enqueue_request(mock_request)
9799

98100
assert req_id == 8
@@ -104,8 +106,8 @@ def test_enqueue_request_with_query(executor_queue):
104106
"""Test enqueuing a request with query data."""
105107
mock_request = Mock()
106108
query_data = [1, 2, 3, 4]
107-
108-
req_id = executor_queue.enqueue_request(mock_request, query=query_data)
109+
with patch.object(executor_queue, '_generate_child_request_ids'):
110+
req_id = executor_queue.enqueue_request(mock_request, query=query_data)
109111

110112
assert req_id == 8
111113

@@ -115,6 +117,31 @@ def test_enqueue_request_with_query(executor_queue):
115117
assert item.request == mock_request
116118

117119

120+
@pytest.mark.parametrize("n_children", [0, 1, 2])
121+
def test_enqueue_request_with_child_ids(executor_queue, n_children):
122+
"""Test enqueuing a request with query data."""
123+
mock_request = Mock()
124+
query_data = [1, 2, 3, 4]
125+
with patch.object(executor_queue,
126+
'_get_num_child_requests') as mock_children:
127+
mock_children.return_value = n_children
128+
req_id = executor_queue.enqueue_request(mock_request, query=query_data)
129+
130+
assert req_id == 8
131+
132+
# Verify the item was enqueued with child ids
133+
item = executor_queue.request_queue.get_nowait()
134+
assert item.id == req_id
135+
assert item.request == mock_request
136+
if n_children == 0:
137+
assert item.child_req_ids is None
138+
else:
139+
assert item.child_req_ids is not None
140+
assert len(item.child_req_ids) == n_children
141+
assert item.child_req_ids == list(
142+
range(1 + req_id, 1 + req_id + n_children))
143+
144+
118145
def test_enqueue_cancel_request(executor_queue):
119146
"""Test enqueuing a cancel request."""
120147
req_id = 42
@@ -253,11 +280,10 @@ def test_validate_and_filter_requests(executor_queue):
253280
)
254281
def test_merge_requests_default(mock_convert, executor_queue):
255282
"""Test merging requests with default configuration."""
256-
mock_llm_request = Mock()
283+
mock_llm_request = Mock(child_requests=[])
257284
mock_convert.return_value = mock_llm_request
258285

259286
requests = [RequestQueueItem(1, Mock()), RequestQueueItem(2, Mock())]
260-
261287
result = executor_queue._merge_requests(requests)
262288

263289
assert len(result) == 2

0 commit comments

Comments
 (0)