
Commit f27fa03

committed: integration-vllm-test
1 parent a07c9e2 commit f27fa03

File tree

1 file changed: +233 additions, −0 deletions


test/integration/test_vllm.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os
import random
from pathlib import Path

import numpy as np
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True)


VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None
TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

if not VLLM_AVAILABLE:
    pytest.skip("vLLM not installed", allow_module_level=True)

if not TRANSFORMERS_AVAILABLE:
    pytest.skip("transformers not installed", allow_module_level=True)

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.prototype.mx_formats import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
    CutlassInt4PackedLayout,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
)


class TestVLLMIntegration:
    """Integration tests for vLLM with quantized models."""

    @classmethod
    def setup_class(cls):
        """Set up test environment."""
        # Set seeds for reproducibility
        cls.set_seed(42)

        # Set vLLM environment variables
        os.environ["VLLM_USE_V1"] = "1"
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    @staticmethod
    def set_seed(seed):
        """Set random seeds for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def get_quantization_config(self, quant_type: str, granularity: str = "per_tensor"):
        """Create TorchAo quantization config based on provided parameters."""
        granularity_mapping = {
            "per_row": PerRow(),
            "per_tensor": PerTensor(),
        }

        gran = granularity_mapping[granularity]

        match quant_type:
            case "autoquant":
                return TorchAoConfig("autoquant", min_sqnr=40.0)
            case "fp8":
                return TorchAoConfig(
                    Float8DynamicActivationFloat8WeightConfig(granularity=gran)
                )
            case "int4_weight_only":
                return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
            case "int8_weight_only":
                return TorchAoConfig(Int8WeightOnlyConfig())
            case "int8_dynamic_act_int8_weight":
                return TorchAoConfig(Int8DynamicActivationInt8WeightConfig())
            case "gemlite":
                return TorchAoConfig(GemliteUIntXWeightOnlyConfig())
            case "A4W4":
                return TorchAoConfig(Int4DynamicActivationInt4WeightConfig())
            case "A8W4":
                return TorchAoConfig(
                    Int8DynamicActivationInt4WeightConfig(
                        layout=CutlassInt4PackedLayout()
                    )
                )
            case "mxfp8":
                return TorchAoConfig(MXFPInferenceConfig())
            case "mxfp4":
                return TorchAoConfig(
                    MXFPInferenceConfig(
                        activation_dtype=torch.float4_e2m1fn_x2,
                        weight_dtype=torch.float4_e2m1fn_x2,
                        block_size=32,
                        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
                    )
                )
            case _:
                raise ValueError(f"Unsupported quantization type: {quant_type}")

    def quantize_and_save_model(
        self,
        model_name: str,
        quant_type: str,
        output_dir: Path,
        granularity: str = "per_tensor",
    ):
        """Quantize a model and save it to disk."""
        # Get quantization config
        quantization_config = self.get_quantization_config(quant_type, granularity)

        # Load and quantize model
        print(f"Loading and quantizing model with {quant_type}...")
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
            quantization_config=quantization_config,
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Quick test generation to verify model works
        test_input = "Hello, world!"
        input_ids = tokenizer(test_input, return_tensors="pt").to(
            quantized_model.device
        )

        with torch.no_grad():
            output = quantized_model.generate(**input_ids, max_new_tokens=5)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f"Quick test - Input: {test_input}, Output: {decoded}")

        # Save quantized model
        print(f"Saving quantized model to {output_dir}...")
        quantized_model.save_pretrained(output_dir, safe_serialization=False)
        tokenizer.save_pretrained(output_dir)

        # Clean up to free memory
        del quantized_model
        torch.cuda.empty_cache()

        return output_dir

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
    @pytest.mark.parametrize(
        "quant_type,granularity",
        [
            ("fp8", "per_tensor"),
            ("fp8", "per_row"),
            ("int8_weight_only", "per_tensor"),
            ("int4_weight_only", "per_tensor"),
            # ("A8W4", "per_tensor"),
        ],
    )
    @pytest.mark.parametrize("compile", [True, False])
    @pytest.mark.parametrize(
        "tp_size", [1, 2] if torch.cuda.device_count() >= 2 else [1]
    )
    def test_vllm_smoke_test(self, tmp_path, quant_type, granularity, compile, tp_size):
        """Test vLLM generation with quantized models."""
        # Skip per_row tests if not supported
        if granularity == "per_row" and not torch.cuda.get_device_capability()[0] >= 9:
            pytest.skip("Per-row quantization requires SM90+")

        # Use a small model for testing
        base_model = "facebook/opt-125m"

        # Quantize the model
        output_dir = tmp_path / f"{quant_type}-{granularity}-opt-125m"
        quantized_model_path = self.quantize_and_save_model(
            base_model, quant_type, output_dir, granularity
        )

        # Test generation with vLLM
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            seed=42,
            max_tokens=16,  # Small for testing
        )

        # Create LLM instance
        llm = LLM(
            model=str(quantized_model_path),
            tensor_parallel_size=tp_size,
            enforce_eager=not compile,
            dtype="bfloat16",  # Match quantization dtype
        )

        # Test prompts
        prompts = [
            "Hello, my name is",
            "The capital of France is",
        ]

        # Generate outputs
        outputs = llm.generate(prompts, sampling_params)

        # Verify outputs
        assert len(outputs) == len(prompts)
        for output in outputs:
            assert output.prompt in prompts
            assert len(output.outputs) > 0
            generated_text = output.outputs[0].text
            assert isinstance(generated_text, str)
            assert len(generated_text) > 0

        # Clean up
        del llm
        torch.cuda.empty_cache()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
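
A minimal way to exercise just one of the parametrized cases from this file, assuming it is invoked from the repository root; the path and the -k keyword expression below are illustrative, not part of the commit:

# Run only the fp8 / per_tensor cases of the vLLM integration test.
import pytest

raise SystemExit(
    pytest.main(["test/integration/test_vllm.py", "-v", "-k", "fp8 and per_tensor"])
)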
