Skip to content

Commit 73ad66d

Browse files
committed
update another test for overlap
Signed-off-by: Izzy Putterman <[email protected]>
1 parent 187d93f commit 73ad66d

File tree

3 files changed

+16
-18
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,12 +1203,12 @@ def _executor_loop_overlap(self):
12031203
self._process_previous_batch()
12041204
self.previous_batch: Optional[BatchState] = None
12051205

1206-
if self.enable_iter_perf_stats:
1207-
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
1208-
'num_ctx_tokens']
1209-
12101206
if scheduled_batch.batch_size > 0:
12111207

1208+
if self.enable_iter_perf_stats:
1209+
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
1210+
'num_ctx_tokens']
1211+
12121212
self.previous_batch = BatchState(
12131213
sample_state=sample_state,
12141214
iter_start_time=iter_start_time,

tests/integration/defs/llmapi/test_llm_api_connector.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
8787
assert len(scheduler.update_state_after_alloc.call_args.args[1]) == 1
8888

8989
# With the overlap scheduler, we generate one extra token.
90-
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
91-
use_overlap_scheduler)
90+
assert scheduler.build_connector_meta.call_count == NUM_TOKENS
9291

9392
# We should have a single `SchedulerOutput` per forward pass.
9493
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
@@ -108,8 +107,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
108107
assert len(scheduler_output.cached_requests[0].new_tokens) == 1
109108

110109
# We call `start_load_kv` once at the beginning of each forward pass.
111-
assert worker.start_load_kv.call_count == NUM_TOKENS + int(
112-
use_overlap_scheduler)
110+
assert worker.start_load_kv.call_count == NUM_TOKENS
113111

114112
# Only called once when the request is received.
115113
assert scheduler.get_num_new_matched_tokens.call_count == 1
@@ -118,19 +116,16 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
118116
for call in worker.wait_for_layer_load.call_args_list) + 1
119117

120118
# Called num_layers * num_forward_passes times.
121-
assert worker.wait_for_layer_load.call_count == num_layers * (
122-
NUM_TOKENS + int(use_overlap_scheduler))
123-
assert worker.save_kv_layer.call_count == num_layers * (
124-
NUM_TOKENS + int(use_overlap_scheduler))
119+
assert worker.wait_for_layer_load.call_count == num_layers * (NUM_TOKENS)
120+
assert worker.save_kv_layer.call_count == num_layers * (NUM_TOKENS)
125121

126122
for i, call in enumerate(worker.wait_for_layer_load.call_args_list):
127123
assert call.args[0] == i % num_layers
128124

129125
for i, call in enumerate(worker.save_kv_layer.call_args_list):
130126
assert call.args[0] == i % num_layers
131127

132-
assert worker.wait_for_save.call_count == NUM_TOKENS + int(
133-
use_overlap_scheduler)
128+
assert worker.wait_for_save.call_count == NUM_TOKENS
134129

135130
assert scheduler.request_finished.call_count == 1
136131

@@ -238,8 +233,7 @@ def test_connector_scheduler_output(enforce_single_worker, model_with_connector,
238233
NUM_INPUT_TOKENS / BLOCK_SIZE)
239234

240235
# Additional token when using the overlap scheduler.
241-
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
242-
use_overlap_scheduler)
236+
assert scheduler.build_connector_meta.call_count == NUM_TOKENS
243237

244238
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
245239
sched_output = call.args[0]

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ def test_llm_reward_model():
174174

175175

176176
def test_llm_perf_metrics():
177-
llm = LLM(model=llama_model_path, kv_cache_config=global_kvcache_config)
177+
disable_overlap_scheduler = False
178+
llm = LLM(model=llama_model_path,
179+
kv_cache_config=global_kvcache_config,
180+
disable_overlap_scheduler=disable_overlap_scheduler)
178181
sampling_params = SamplingParams(max_tokens=10, return_perf_metrics=True)
179182
outputs = llm.generate(prompts, sampling_params)
180183
assert outputs[0].outputs[0].request_perf_metrics is not None
@@ -194,7 +197,8 @@ def test_llm_perf_metrics():
194197
assert kv_cache_metrics.kv_cache_hit_rate == 0
195198

196199
assert perf_metrics.first_iter is not None
197-
assert perf_metrics.iter - perf_metrics.first_iter == sampling_params.max_tokens - 1
200+
assert perf_metrics.iter - perf_metrics.first_iter == sampling_params.max_tokens - (
201+
1 if disable_overlap_scheduler else 2)
198202
assert perf_metrics.last_iter == perf_metrics.iter
199203

200204

0 commit comments

Comments (0)