Commit c06d97f

integration-vllm-test
stack-info: PR: #2258, branch: drisspg/stack/58
1 parent 1017c7e commit c06d97f

File tree

1 file changed: +257 −0 lines changed

test/integration/test_vllm.py

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os
import random
from pathlib import Path

import numpy as np
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True)


VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None
TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

if not VLLM_AVAILABLE:
    pytest.skip("vLLM not installed", allow_module_level=True)

if not TRANSFORMERS_AVAILABLE:
    pytest.skip("transformers not installed", allow_module_level=True)

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.prototype.mx_formats import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
    CutlassInt4PackedLayout,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
)


class TestVLLMIntegration:
    """Integration tests for vLLM with quantized models."""

    @classmethod
    def setup_class(cls):
        """Set up test environment."""
        # Set seeds for reproducibility
        cls.set_seed(42)

        # Set vLLM environment variables
        os.environ["VLLM_USE_V1"] = "1"
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    @classmethod
    def teardown_class(cls):
        """Clean up after all tests."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def setup_method(self, method):
        """Clean up before each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def teardown_method(self, method):
        """Clean up after each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    @staticmethod
    def set_seed(seed):
        """Set random seeds for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def get_quantization_config(self, quant_type: str, granularity: str = "per_tensor"):
        """Create TorchAo quantization config based on provided parameters."""
        granularity_mapping = {
            "per_row": PerRow(),
            "per_tensor": PerTensor(),
        }

        gran = granularity_mapping[granularity]

        match quant_type:
            case "autoquant":
                return TorchAoConfig("autoquant", min_sqnr=40.0)
            case "fp8":
                return TorchAoConfig(
                    Float8DynamicActivationFloat8WeightConfig(granularity=gran)
                )
            case "int4_weight_only":
                return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
            case "int8_weight_only":
                return TorchAoConfig(Int8WeightOnlyConfig())
            case "int8_dynamic_act_int8_weight":
                return TorchAoConfig(Int8DynamicActivationInt8WeightConfig())
            case "gemlite":
                return TorchAoConfig(GemliteUIntXWeightOnlyConfig())
            case "A4W4":
                return TorchAoConfig(Int4DynamicActivationInt4WeightConfig())
            case "A8W4":
                return TorchAoConfig(
                    Int8DynamicActivationInt4WeightConfig(
                        layout=CutlassInt4PackedLayout()
                    )
                )
            case "mxfp8":
                return TorchAoConfig(MXFPInferenceConfig())
            case "mxfp4":
                return TorchAoConfig(
                    MXFPInferenceConfig(
                        activation_dtype=torch.float4_e2m1fn_x2,
                        weight_dtype=torch.float4_e2m1fn_x2,
                        block_size=32,
                        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
                    )
                )
            case _:
                raise ValueError(f"Unsupported quantization type: {quant_type}")

    def quantize_and_save_model(
        self,
        model_name: str,
        quant_type: str,
        output_dir: Path,
        granularity: str = "per_tensor",
    ):
        """Quantize a model and save it to disk."""
        # Get quantization config
        quantization_config = self.get_quantization_config(quant_type, granularity)

        # Load and quantize model
        print(f"Loading and quantizing model with {quant_type}...")
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
            quantization_config=quantization_config,
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Quick test generation to verify model works
        test_input = "Hello, world!"
        input_ids = tokenizer(test_input, return_tensors="pt").to(
            quantized_model.device
        )

        with torch.no_grad():
            output = quantized_model.generate(**input_ids, max_new_tokens=5)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f"Quick test - Input: {test_input}, Output: {decoded}")

        # Save quantized model
        print(f"Saving quantized model to {output_dir}...")
        quantized_model.save_pretrained(output_dir, safe_serialization=False)
        tokenizer.save_pretrained(output_dir)

        # Clean up to free memory
        del quantized_model
        torch.cuda.empty_cache()

        return output_dir

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
    @pytest.mark.parametrize(
        "quant_type,granularity",
        [
            ("fp8", "per_tensor"),
            ("fp8", "per_row"),
            ("int8_weight_only", "per_tensor"),
            ("int4_weight_only", "per_tensor"),
            ("A8W4", "per_tensor"),
        ],
    )
    @pytest.mark.parametrize("compile", [True, False])
    @pytest.mark.parametrize(
        "tp_size", [1, 2] if torch.cuda.device_count() > 1 else [1]
    )
    def test_vllm_smoke_test(self, tmp_path, quant_type, granularity, compile, tp_size):
        """Test vLLM generation with quantized models."""
        # Skip per_row tests if not supported
        torch._dynamo.reset()
        if granularity == "per_row" and not torch.cuda.get_device_capability()[0] >= 9:
            pytest.skip("Per-row quantization requires SM90+")

        # Use a small model for testing
        base_model = "facebook/opt-125m"

        # Quantize the model
        output_dir = tmp_path / f"{quant_type}-{granularity}-opt-125m"
        quantized_model_path = self.quantize_and_save_model(
            base_model, quant_type, output_dir, granularity
        )

        # Test generation with vLLM
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            seed=42,
            max_tokens=16,  # Small for testing
        )

        # Create LLM instance
        llm = LLM(
            model=str(quantized_model_path),
            tensor_parallel_size=tp_size,
            enforce_eager=not compile,
            dtype="bfloat16",
            num_gpu_blocks_override=128,
        )

        # Test prompts
        prompts = [
            "Hello, my name is",
            "The capital of France is",
        ]

        # Generate outputs
        outputs = llm.generate(prompts, sampling_params)

        # Verify outputs
        assert len(outputs) == len(prompts)
        for output in outputs:
            assert output.prompt in prompts
            assert len(output.outputs) > 0
            generated_text = output.outputs[0].text
            assert isinstance(generated_text, str)
            assert len(generated_text) > 0

        # Clean up
        del llm
        torch.cuda.empty_cache()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
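
Note (not part of the committed file): test_vllm_smoke_test is parametrized over quantization type, granularity, compile mode, and tensor-parallel size, so the full matrix quantizes and serves facebook/opt-125m many times. As a minimal sketch of narrowing a local run, assuming pytest's standard -k keyword filter against the generated test IDs:

# Hypothetical runner script, not part of this commit: select only the fp8 cases.
# Equivalent CLI invocation: pytest test/integration/test_vllm.py -v -k fp8
import pytest

if __name__ == "__main__":
    pytest.main(["test/integration/test_vllm.py", "-v", "-k", "fp8"])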

0 commit comments
