
Commit edd42f9

integration-vllm-test
stack-info: PR: #2258, branch: drisspg/stack/58
1 parent 1017c7e commit edd42f9

File tree

1 file changed: +254 -0 lines changed


test/integration/test_vllm.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os
import random
from pathlib import Path

import numpy as np
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True)


VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None
TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

if not VLLM_AVAILABLE:
    pytest.skip("vLLM not installed", allow_module_level=True)

if not TRANSFORMERS_AVAILABLE:
    pytest.skip("transformers not installed", allow_module_level=True)

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.prototype.mx_formats import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
    CutlassInt4PackedLayout,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
)


class TestVLLMIntegration:
    """Integration tests for vLLM with quantized models."""

    @classmethod
    def setup_class(cls):
        """Set up test environment."""
        # Set seeds for reproducibility
        cls.set_seed(42)

        # Set vLLM environment variables
        os.environ["VLLM_USE_V1"] = "1"
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    @classmethod
    def teardown_class(cls):
        """Clean up after all tests."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def setup_method(self, method):
        """Clean up before each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def teardown_method(self, method):
        """Clean up after each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    @staticmethod
    def set_seed(seed):
        """Set random seeds for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def get_quantization_config(self, quant_type: str, granularity: str = "per_tensor"):
        """Create TorchAo quantization config based on provided parameters."""
        granularity_mapping = {
            "per_row": PerRow(),
            "per_tensor": PerTensor(),
        }

        gran = granularity_mapping[granularity]

        if quant_type == "autoquant":
            return TorchAoConfig("autoquant", min_sqnr=40.0)
        elif quant_type == "fp8":
            return TorchAoConfig(
                Float8DynamicActivationFloat8WeightConfig(granularity=gran)
            )
        elif quant_type == "int4_weight_only":
            return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
        elif quant_type == "int8_weight_only":
            return TorchAoConfig(Int8WeightOnlyConfig())
        elif quant_type == "int8_dynamic_act_int8_weight":
            return TorchAoConfig(Int8DynamicActivationInt8WeightConfig())
        elif quant_type == "gemlite":
            return TorchAoConfig(GemliteUIntXWeightOnlyConfig())
        elif quant_type == "A4W4":
            return TorchAoConfig(Int4DynamicActivationInt4WeightConfig())
        elif quant_type == "A8W4":
            return TorchAoConfig(
                Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
            )
        elif quant_type == "mxfp8":
            return TorchAoConfig(MXFPInferenceConfig())
        elif quant_type == "mxfp4":
            return TorchAoConfig(
                MXFPInferenceConfig(
                    activation_dtype=torch.float4_e2m1fn_x2,
                    weight_dtype=torch.float4_e2m1fn_x2,
                    block_size=32,
                    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
                )
            )
        else:
            raise ValueError(f"Unsupported quantization type: {quant_type}")

    def quantize_and_save_model(
        self,
        model_name: str,
        quant_type: str,
        output_dir: Path,
        granularity: str = "per_tensor",
    ):
        """Quantize a model and save it to disk."""
        # Get quantization config
        quantization_config = self.get_quantization_config(quant_type, granularity)

        # Load and quantize model
        print(f"Loading and quantizing model with {quant_type}...")
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
            quantization_config=quantization_config,
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Quick test generation to verify model works
        test_input = "Hello, world!"
        input_ids = tokenizer(test_input, return_tensors="pt").to(
            quantized_model.device
        )

        with torch.no_grad():
            output = quantized_model.generate(**input_ids, max_new_tokens=5)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f"Quick test - Input: {test_input}, Output: {decoded}")

        # Save quantized model
        print(f"Saving quantized model to {output_dir}...")
        quantized_model.save_pretrained(output_dir, safe_serialization=False)
        tokenizer.save_pretrained(output_dir)

        # Clean up to free memory
        del quantized_model
        torch.cuda.empty_cache()

        return output_dir

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
    @pytest.mark.parametrize(
        "quant_type,granularity",
        [
            ("fp8", "per_tensor"),
            ("fp8", "per_row"),
            ("int8_weight_only", "per_tensor"),
            ("int4_weight_only", "per_tensor"),
            # ("A8W4", "per_tensor"),  # Not supported on CPU backend
        ],
    )
    @pytest.mark.parametrize("compile", [True, False])
    @pytest.mark.parametrize(
        "tp_size", [1, 2] if torch.cuda.device_count() > 1 else [1]
    )
    def test_vllm_smoke_test(self, tmp_path, quant_type, granularity, compile, tp_size):
        """Test vLLM generation with quantized models."""
        # Skip per_row tests if not supported
        torch._dynamo.reset()
        if granularity == "per_row" and not torch.cuda.get_device_capability()[0] >= 9:
            pytest.skip("Per-row quantization requires SM90+")

        # Use a small model for testing
        base_model = "facebook/opt-125m"

        # Quantize the model
        output_dir = tmp_path / f"{quant_type}-{granularity}-opt-125m"
        quantized_model_path = self.quantize_and_save_model(
            base_model, quant_type, output_dir, granularity
        )

        # Test generation with vLLM
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            seed=42,
            max_tokens=16,  # Small for testing
        )

        # Create LLM instance
        llm = LLM(
            model=str(quantized_model_path),
            tensor_parallel_size=tp_size,
            enforce_eager=not compile,
            dtype="bfloat16",
            num_gpu_blocks_override=128,
        )

        # Test prompts
        prompts = [
            "Hello, my name is",
            "The capital of France is",
        ]

        # Generate outputs
        outputs = llm.generate(prompts, sampling_params)

        # Verify outputs
        assert len(outputs) == len(prompts)
        for output in outputs:
            assert output.prompt in prompts
            assert len(output.outputs) > 0
            generated_text = output.outputs[0].text
            assert isinstance(generated_text, str)
            assert len(generated_text) > 0

        # Clean up
        del llm
        torch.cuda.empty_cache()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
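
For context, the test above boils down to a quantize-then-serve round trip: quantize with transformers' TorchAoConfig integration, save the checkpoint, then load it with vLLM and generate. Below is a minimal standalone sketch of that flow, not part of the committed file, assuming a CUDA machine with torchao, transformers, and vLLM installed; the model name, config choice, and output path are illustrative.

# Standalone sketch of the quantize -> save -> serve flow exercised by test_vllm_smoke_test.
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

model_name = "facebook/opt-125m"   # same small model the test uses
out_dir = Path("opt-125m-fp8")     # illustrative output directory

# 1. Quantize with transformers' TorchAoConfig integration and save to disk.
quant_config = TorchAoConfig(
    Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.save_pretrained(out_dir, safe_serialization=False)
tokenizer.save_pretrained(out_dir)

# 2. Serve the saved quantized checkpoint with vLLM and generate.
llm = LLM(model=str(out_dir), dtype="bfloat16")
outputs = llm.generate(
    ["Hello, my name is"], SamplingParams(max_tokens=16, seed=42)
)
print(outputs[0].outputs[0].text)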
