@@ -87,8 +87,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
87
87
assert len (scheduler .update_state_after_alloc .call_args .args [1 ]) == 1
88
88
89
89
# Expect exactly NUM_TOKENS calls, whether or not the overlap scheduler is used.
90
- assert scheduler .build_connector_meta .call_count == NUM_TOKENS + int (
91
- use_overlap_scheduler )
90
+ assert scheduler .build_connector_meta .call_count == NUM_TOKENS
92
91
93
92
# We should have a single `SchedulerOutput` per forward pass.
94
93
for i , call in enumerate (scheduler .build_connector_meta .call_args_list ):
@@ -108,8 +107,7 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
108
107
assert len (scheduler_output .cached_requests [0 ].new_tokens ) == 1
109
108
110
109
# We call `start_load_kv` once at the beginning of each forward pass.
111
- assert worker .start_load_kv .call_count == NUM_TOKENS + int (
112
- use_overlap_scheduler )
110
+ assert worker .start_load_kv .call_count == NUM_TOKENS
113
111
114
112
# Only called once when the request is received.
115
113
assert scheduler .get_num_new_matched_tokens .call_count == 1
@@ -118,19 +116,16 @@ def test_connector_simple(enforce_single_worker, model_with_connector,
118
116
for call in worker .wait_for_layer_load .call_args_list ) + 1
119
117
120
118
# Called num_layers * num_forward_passes times.
121
- assert worker .wait_for_layer_load .call_count == num_layers * (
122
- NUM_TOKENS + int (use_overlap_scheduler ))
123
- assert worker .save_kv_layer .call_count == num_layers * (
124
- NUM_TOKENS + int (use_overlap_scheduler ))
119
+ assert worker .wait_for_layer_load .call_count == num_layers * (NUM_TOKENS )
120
+ assert worker .save_kv_layer .call_count == num_layers * (NUM_TOKENS )
125
121
126
122
for i , call in enumerate (worker .wait_for_layer_load .call_args_list ):
127
123
assert call .args [0 ] == i % num_layers
128
124
129
125
for i , call in enumerate (worker .save_kv_layer .call_args_list ):
130
126
assert call .args [0 ] == i % num_layers
131
127
132
- assert worker .wait_for_save .call_count == NUM_TOKENS + int (
133
- use_overlap_scheduler )
128
+ assert worker .wait_for_save .call_count == NUM_TOKENS
134
129
135
130
assert scheduler .request_finished .call_count == 1
136
131
@@ -238,8 +233,7 @@ def test_connector_scheduler_output(enforce_single_worker, model_with_connector,
238
233
NUM_INPUT_TOKENS / BLOCK_SIZE )
239
234
240
235
# Expect exactly NUM_TOKENS calls, with or without the overlap scheduler.
241
- assert scheduler .build_connector_meta .call_count == NUM_TOKENS + int (
242
- use_overlap_scheduler )
236
+ assert scheduler .build_connector_meta .call_count == NUM_TOKENS
243
237
244
238
for i , call in enumerate (scheduler .build_connector_meta .call_args_list ):
245
239
sched_output = call .args [0 ]
0 commit comments