@@ -75,7 +75,8 @@ def test_enqueue_requests(executor_queue):
75
75
"""Test enqueuing multiple requests."""
76
76
mock_requests = [Mock (), Mock (), Mock ()]
77
77
78
- with patch ('time.time' , return_value = 1234.5 ):
78
+ with (patch ('time.time' , return_value = 1234.5 ),
79
+ patch .object (executor_queue , '_generate_child_request_ids' )):
79
80
req_ids = executor_queue .enqueue_requests (mock_requests ) # type: ignore
80
81
81
82
assert len (req_ids ) == 3
@@ -92,7 +93,8 @@ def test_enqueue_request_single(executor_queue):
92
93
"""Test enqueuing a single request."""
93
94
mock_request = Mock ()
94
95
95
- with patch ('time.time' , return_value = 1234.5 ):
96
+ with (patch ('time.time' , return_value = 1234.5 ),
97
+ patch .object (executor_queue , '_generate_child_request_ids' )):
96
98
req_id = executor_queue .enqueue_request (mock_request )
97
99
98
100
assert req_id == 8
@@ -104,8 +106,8 @@ def test_enqueue_request_with_query(executor_queue):
104
106
"""Test enqueuing a request with query data."""
105
107
mock_request = Mock ()
106
108
query_data = [1 , 2 , 3 , 4 ]
107
-
108
- req_id = executor_queue .enqueue_request (mock_request , query = query_data )
109
+ with patch . object ( executor_queue , '_generate_child_request_ids' ):
110
+ req_id = executor_queue .enqueue_request (mock_request , query = query_data )
109
111
110
112
assert req_id == 8
111
113
@@ -115,6 +117,31 @@ def test_enqueue_request_with_query(executor_queue):
115
117
assert item .request == mock_request
116
118
117
119
120
+ @pytest .mark .parametrize ("n_children" , [0 , 1 , 2 ])
121
+ def test_enqueue_request_with_child_ids (executor_queue , n_children ):
122
+ """Test enqueuing a request with query data."""
123
+ mock_request = Mock ()
124
+ query_data = [1 , 2 , 3 , 4 ]
125
+ with patch .object (executor_queue ,
126
+ '_get_num_child_requests' ) as mock_children :
127
+ mock_children .return_value = n_children
128
+ req_id = executor_queue .enqueue_request (mock_request , query = query_data )
129
+
130
+ assert req_id == 8
131
+
132
+ # Verify the item was enqueued with child ids
133
+ item = executor_queue .request_queue .get_nowait ()
134
+ assert item .id == req_id
135
+ assert item .request == mock_request
136
+ if n_children == 0 :
137
+ assert item .child_req_ids is None
138
+ else :
139
+ assert item .child_req_ids is not None
140
+ assert len (item .child_req_ids ) == n_children
141
+ assert item .child_req_ids == list (
142
+ range (1 + req_id , 1 + req_id + n_children ))
143
+
144
+
118
145
def test_enqueue_cancel_request (executor_queue ):
119
146
"""Test enqueuing a cancel request."""
120
147
req_id = 42
@@ -253,11 +280,10 @@ def test_validate_and_filter_requests(executor_queue):
253
280
)
254
281
def test_merge_requests_default (mock_convert , executor_queue ):
255
282
"""Test merging requests with default configuration."""
256
- mock_llm_request = Mock ()
283
+ mock_llm_request = Mock (child_requests = [] )
257
284
mock_convert .return_value = mock_llm_request
258
285
259
286
requests = [RequestQueueItem (1 , Mock ()), RequestQueueItem (2 , Mock ())]
260
-
261
287
result = executor_queue ._merge_requests (requests )
262
288
263
289
assert len (result ) == 2
0 commit comments