Skip to content

Commit c2a8ac7

Browse files
authored
[CI/Build] Add E2E tests for MLPSpeculator (#5791)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent f178e56 commit c2a8ac7

File tree

1 file changed

+216
-0
lines changed

1 file changed

+216
-0
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various numbers of speculative tokens.

With those tests, we can say at least, MLPSpeculator would not break the
correctness for the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test

# main model
MAIN_MODEL = "ibm-granite/granite-3b-code-instruct"

# speculative model
SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"

# max. number of speculative tokens: this corresponds to
# n_predict in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5

# precision
PRECISION = "float16"
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": True,  # skip cuda graph recording for fast test
        "use_v2_block_manager": True,  # required for spec decode
        "disable_log_stats": False,  # print spec metrics
        "dtype": PRECISION,  # precision
        "model": MAIN_MODEL,  # main model
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                    batch_size: int, output_len: int):
    """Check that MLPSpeculator output exactly matches the non-speculative
    greedy baseline, for both a single-sequence batch and a larger batch.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Deliberately tiny KV-cache budget so sequences get preempted:
        # 2 blocks for the small prompt, 256 // 8 blocks for generation.
        "block_size": 8,
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,
        "enforce_eager": True,  # skip cuda graph recording for fast test
        "use_v2_block_manager": True,  # required for spec decode
        "dtype": PRECISION,  # precision
        "model": MAIN_MODEL,  # main model
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("output_len", [128])  # small output len for fast test
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                    test_llm_generator,
                                                    batch_size: int,
                                                    output_len: int):
    """Greedy equality must still hold when some sequences are preempted
    mid-generation, forced by the small KV-cache budget configured above.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": True,  # skip cuda graph recording for fast test
        "use_v2_block_manager": True,  # required for spec decode
        "dtype": PRECISION,  # precision
        "model": MAIN_MODEL,  # main model
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Sweep the full range of num. speculative tokens.
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
        } for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("output_len", [32])  # small output len for fast test
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Greedy equality with the non-speculative baseline must hold for
    every supported value of num_speculative_tokens (1..MAX_SPEC_TOKENS).
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": True,  # skip cuda graph recording for fast test
        "use_v2_block_manager": True,  # required for spec decode
        "dtype": PRECISION,  # precision
        "model": MAIN_MODEL,  # main model
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
    "speculative_disable_by_batch_size": 4,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize("output_len", [32])  # small output len for fast test
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Greedy equality must hold both below (batch_size=1) and above
    (batch_size=5) the batch-size threshold at which speculation is
    disabled.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True)

0 commit comments

Comments
 (0)