
Commit 1fc853f

committed
integration-vllm-test
stack-info: PR: #2258, branch: drisspg/stack/58
1 parent 1017c7e commit 1fc853f

File tree

1 file changed: +283 -0 lines changed


test/integration/test_vllm.py

Lines changed: 283 additions & 0 deletions
@@ -0,0 +1,283 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True)


VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None
TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

if not VLLM_AVAILABLE:
    pytest.skip("vLLM not installed", allow_module_level=True)

if not TRANSFORMERS_AVAILABLE:
    pytest.skip("transformers not installed", allow_module_level=True)

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.prototype.mx_formats import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
    CutlassInt4PackedLayout,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
)

def get_tests() -> List[Tuple[str, str]]:
    """Get all the tests to run based on device info."""

    BASE_TESTS = [("int8_weight_only", "per_tensor")]
    SM89_TESTS = [("fp8", "per_tensor"), ("fp8", "per_row")]
    SM90_ONLY_TESTS = [("A8W4", "per_tensor")]
    SM100_TESTS = [
        ("mxfp8", "per_tensor")
    ]  # Failing for: https://github.com/pytorch/ao/issues/2239

    # Check CUDA availability first
    if not torch.cuda.is_available():
        return []  # No CUDA, no tests

    major, minor = torch.cuda.get_device_capability()

    # Build test list based on compute capability
    all_tests = []

    # Always include base tests if we have CUDA
    all_tests.extend(BASE_TESTS)

    # Add SM89+ tests
    if major > 8 or (major == 8 and minor >= 9):
        all_tests.extend(SM89_TESTS)

    # Add SM100+ tests
    if major >= 10:
        all_tests.extend(SM100_TESTS)

    # Only works on SM90
    if major == 9:
        all_tests.extend(SM90_ONLY_TESTS)

    return all_tests

class TestVLLMIntegration:
    """Integration tests for vLLM with quantized models."""

    @classmethod
    def setup_class(cls):
        """Set up test environment."""
        # Set seeds for reproducibility
        cls.set_seed(42)

        # Set vLLM environment variables
        os.environ["VLLM_USE_V1"] = "1"
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    @classmethod
    def teardown_class(cls):
        """Clean up after all tests."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def setup_method(self, method):
        """Clean up before each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def teardown_method(self, method):
        """Clean up after each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    @staticmethod
    def set_seed(seed):
        """Set random seeds for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def get_quantization_config(self, quant_type: str, granularity: str = "per_tensor"):
        """Create TorchAo quantization config based on provided parameters."""
        granularity_mapping = {
            "per_row": PerRow(),
            "per_tensor": PerTensor(),
        }

        gran = granularity_mapping[granularity]

        if quant_type == "autoquant":
            return TorchAoConfig("autoquant", min_sqnr=40.0)
        elif quant_type == "fp8":
            return TorchAoConfig(
                Float8DynamicActivationFloat8WeightConfig(granularity=gran)
            )
        elif quant_type == "int4_weight_only":
            return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
        elif quant_type == "int8_weight_only":
            return TorchAoConfig(Int8WeightOnlyConfig())
        elif quant_type == "int8_dynamic_act_int8_weight":
            return TorchAoConfig(Int8DynamicActivationInt8WeightConfig())
        elif quant_type == "gemlite":
            return TorchAoConfig(GemliteUIntXWeightOnlyConfig())
        elif quant_type == "A4W4":
            return TorchAoConfig(Int4DynamicActivationInt4WeightConfig())
        elif quant_type == "A8W4":
            return TorchAoConfig(
                Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
            )
        elif quant_type == "mxfp8":
            return TorchAoConfig(MXFPInferenceConfig())
        elif quant_type == "mxfp4":
            return TorchAoConfig(
                MXFPInferenceConfig(
                    activation_dtype=torch.float4_e2m1fn_x2,
                    weight_dtype=torch.float4_e2m1fn_x2,
                    block_size=32,
                    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
                )
            )
        else:
            raise ValueError(f"Unsupported quantization type: {quant_type}")

    def quantize_and_save_model(
        self,
        model_name: str,
        quant_type: str,
        output_dir: Path,
        granularity: str = "per_tensor",
    ):
        """Quantize a model and save it to disk."""
        # Get quantization config
        quantization_config = self.get_quantization_config(quant_type, granularity)

        # Load and quantize model
        print(f"Loading and quantizing model with {quant_type}...")
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
            quantization_config=quantization_config,
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Quick test generation to verify model works
        test_input = "Hello, world!"
        input_ids = tokenizer(test_input, return_tensors="pt").to(
            quantized_model.device
        )

        with torch.no_grad():
            output = quantized_model.generate(**input_ids, max_new_tokens=5)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f"Quick test - Input: {test_input}, Output: {decoded}")

        # Save quantized model
        print(f"Saving quantized model to {output_dir}...")
        quantized_model.save_pretrained(output_dir, safe_serialization=False)
        tokenizer.save_pretrained(output_dir)

        # Clean up to free memory
        del quantized_model
        torch.cuda.empty_cache()

        return output_dir

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
    @pytest.mark.parametrize("quant_type,granularity", get_tests())
    @pytest.mark.parametrize("compile", [True, False])
    @pytest.mark.parametrize(
        "tp_size", [1, 2] if torch.cuda.device_count() > 1 else [1]
    )
    def test_vllm_smoke_test(self, tmp_path, quant_type, granularity, compile, tp_size):
        """Test vLLM generation with quantized models."""
        # Skip per_row tests if not supported
        torch._dynamo.reset()
        if granularity == "per_row" and not torch.cuda.get_device_capability()[0] >= 9:
            pytest.skip("Per-row quantization requires SM90+")

        # Use a small model for testing
        base_model = "facebook/opt-125m"

        # Quantize the model
        output_dir = tmp_path / f"{quant_type}-{granularity}-opt-125m"
        quantized_model_path = self.quantize_and_save_model(
            base_model, quant_type, output_dir, granularity
        )

        # Test generation with vLLM
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            seed=42,
            max_tokens=16,  # Small for testing
        )

        # Create LLM instance
        llm = LLM(
            model=str(quantized_model_path),
            tensor_parallel_size=tp_size,
            enforce_eager=not compile,
            dtype="bfloat16",
            num_gpu_blocks_override=128,
        )

        # Test prompts
        prompts = [
            "Hello, my name is",
            "The capital of France is",
        ]

        # Generate outputs
        outputs = llm.generate(prompts, sampling_params)

        # Verify outputs
        assert len(outputs) == len(prompts)
        for output in outputs:
            assert output.prompt in prompts
            assert len(output.outputs) > 0
            generated_text = output.outputs[0].text
            assert isinstance(generated_text, str)
            assert len(generated_text) > 0

        # Clean up
        del llm
        torch.cuda.empty_cache()

if __name__ == "__main__":
    pytest.main([__file__, "-v"])
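
Note on running this locally: the parametrization comes from get_tests(), which inspects torch.cuda.get_device_capability(), so the collected cases differ between SM80, SM89, SM90, and SM100 machines. Below is a minimal sketch of selecting a subset of the cases programmatically, assuming vLLM, transformers, and torchao are installed and a CUDA GPU is visible; the -k expression is only an illustrative filter, not part of this commit.

    # Run only the fp8 / per_tensor cases from this file, mirroring the
    # file's own __main__ entry point but with a keyword filter added.
    import pytest

    pytest.main(["test/integration/test_vllm.py", "-v", "-k", "fp8 and per_tensor"])
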

0 commit comments
