Merged

37 commits
1b46297
Add some fullgraph tests
bnellnm Sep 12, 2024
a51e4b2
fix
bnellnm Sep 12, 2024
20bbea1
add some tests
bnellnm Sep 13, 2024
01fc6c4
wip
bnellnm Sep 13, 2024
9bf3b76
wip
bnellnm Sep 13, 2024
b20d1d6
more tests
bnellnm Sep 13, 2024
9cec973
wip
bnellnm Sep 13, 2024
fe894f3
support fp8 opchecks
bnellnm Sep 13, 2024
aa3f3ce
more tests
bnellnm Sep 13, 2024
1e8acf3
more tests
bnellnm Sep 13, 2024
ca6dfc2
add rotary embedding tests
bnellnm Sep 14, 2024
a9b91a5
fix moe stuff
bnellnm Sep 14, 2024
cfdf337
ggml dequant opcheck test
bnellnm Sep 14, 2024
fe6a49d
add ggml_mul_mat opcheck tests
bnellnm Sep 15, 2024
0878421
aqlm opcheck tests
bnellnm Sep 15, 2024
8a0e610
more opcheck tests
bnellnm Sep 16, 2024
707d189
marlin_gemm opcheck test
bnellnm Sep 16, 2024
8a74b87
cleanup cruft
bnellnm Sep 16, 2024
f458c6e
cleanups
bnellnm Sep 16, 2024
cfce968
try to stop OOM errors
bnellnm Sep 17, 2024
5411e3d
try to fix bugs
bnellnm Sep 17, 2024
3b6bbdc
rebase + merge
bnellnm Sep 19, 2024
786e8c1
fix moe test
bnellnm Sep 19, 2024
594281e
fix marlin_gemm_fake
bnellnm Sep 19, 2024
e1f6c89
fix causal conv meta func + opcheck test
bnellnm Sep 19, 2024
2bf161d
format
bnellnm Sep 19, 2024
d89c89a
lock down gptq + awq quantization
bnellnm Sep 23, 2024
3892da2
lockdown more quantization methods
bnellnm Sep 23, 2024
b406755
add gptq_marlin quantization test + fix
bnellnm Sep 23, 2024
0decaa3
format
bnellnm Sep 23, 2024
56b355a
rebase + fix layer_utils
bnellnm Sep 23, 2024
62c33be
split up fullgraph tests
bnellnm Sep 24, 2024
d51c454
fix test pipeline
bnellnm Sep 24, 2024
23bb7cc
remove plain inductor backend
bnellnm Sep 24, 2024
404ff31
format
bnellnm Sep 24, 2024
44977d0
review comments + try to fix marlin_gemm_moe check for AMD
bnellnm Sep 25, 2024
0dfffa6
don't run fullgraph tests on AMD
bnellnm Sep 25, 2024
19 changes: 17 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -70,7 +70,7 @@ steps:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py

- label: Core Test # 10min
mirror_hardwares: [amd]
fast_check: true
@@ -207,6 +207,21 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4

- label: "PyTorch Fullgraph Smoke Test"
fast_check: true
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph_smoke.py

- label: "PyTorch Fullgraph Test"
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 30min each
mirror_hardwares: [amd]
source_file_dependencies:
@@ -352,7 +367,7 @@ steps:
- tests/distributed/
- vllm/compilation
commands:
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
2 changes: 1 addition & 1 deletion csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
std::vector<at::Tensor> result = {out, x.value()};
std::vector<at::Tensor> result = {out};
if (has_z) { result.push_back(out_z); }
return result;
}
4 changes: 2 additions & 2 deletions csrc/torch_bindings.cpp
@@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
"Tensor? index_, Tensor!? x) -> Tensor[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

ops.def(
@@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias_,"
"Tensor? seq_idx_,"
"Tensor? initial_states_,"
"Tensor? final_states_out_,"
"Tensor!? final_states_out_,"
"bool silu_activation) -> Tensor");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
#endif
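The schema edits above tighten the mutation/alias annotations: `Tensor!? x` and `Tensor!? final_states_out_` declare optional tensors that the kernel writes in place, and the return type drops the alias set. That is the information torch.compile and opcheck need to reason about these ops without tracing the CUDA code. The sketch below is illustrative only (a toy `demo::scaled_copy` op, not vLLM's actual binding); it shows the Python-side equivalent of the same two ideas: declaring which arguments are mutated and registering a fake/meta implementation.

```python
import torch

# Toy example, not vLLM's real registration code. "mutates_args" plays the
# role of the "Tensor!" annotation in the C++ schema above, and the fake impl
# lets torch.compile / opcheck infer shapes without running a real kernel.
@torch.library.custom_op("demo::scaled_copy", mutates_args=("out",))
def scaled_copy(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    out.copy_(x * scale)

@scaled_copy.register_fake
def _(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # Nothing to allocate; the op only writes into "out".
    return None

x, out = torch.randn(4), torch.empty(4)
torch.ops.demo.scaled_copy(x, out, 2.0)
```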
45 changes: 8 additions & 37 deletions tests/compile/test_full_graph.py
@@ -1,42 +1,13 @@
import os

import pytest

from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):

# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")

# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
from vllm.compilation.backends import vllm_backend

from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True)
from .utils import TEST_MODELS, check_full_graph_support

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
22 changes: 22 additions & 0 deletions tests/compile/test_full_graph_multi_gpu.py
@@ -0,0 +1,22 @@
import pytest

from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
model = model_info[0]
model_kwargs = model_info[1]

# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")

check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)
13 changes: 13 additions & 0 deletions tests/compile/test_full_graph_smoke.py
@@ -0,0 +1,13 @@
import pytest

from vllm.compilation.backends import vllm_backend

from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
104 changes: 104 additions & 0 deletions tests/compile/utils.py
@@ -0,0 +1,104 @@
import os

import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.utils import is_hip

TEST_MODELS_SMOKE = [
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]

TEST_MODELS = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
"dtype": torch.float16,
"quantization": "fp8"
}),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]

# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("aqlm"): # noqa: SIM223
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm"
}))

# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))

if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))

if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))

if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "gptq_marlin_24"
}))

if is_quant_method_supported("marlin"):
TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
"quantization": "marlin"
}))

if not is_hip() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))


def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

# Inductor doesn't support fp8/gptq_marlin_24 yet.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24") and backend != "eager":
return

set_torch_compile_backend(backend)

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True,
**model_kwargs)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup()


@pytest.fixture(autouse=True)
def dynamo_reset():
yield
torch._dynamo.reset()


@pytest.fixture
def example_prompts() -> List[str]:
prompts = []
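The autouse `dynamo_reset` fixture added above clears Dynamo's compiled-code caches after every test, so graphs and guards compiled in one test cannot leak into the next. A minimal standalone illustration of the same call:

```python
import torch

f = torch.compile(lambda x: x + 1, backend="eager")
f(torch.ones(2))       # populates Dynamo's caches with a compiled graph
torch._dynamo.reset()  # clears compiled graphs and guards, as the fixture does
```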
37 changes: 37 additions & 0 deletions tests/kernels/test_aqlm.py
@@ -0,0 +1,37 @@
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401


def test_aqlm_dequant_opcheck():
codes = torch.randint(-32768,
32767, (22016, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((2, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
codebook_partition_sizes = [11008, 11008]

opcheck(torch.ops._C.aqlm_dequant,
(codes, codebooks, codebook_partition_sizes))


def test_aqlm_gemm_opcheck():
input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
codes = torch.randint(-32768,
32767, (12288, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((3, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
codebook_partition_sizes = [4096, 4096, 4096]
bias = None

opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, None))
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, bias))
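The `opcheck` helper from `tests/kernels/utils` used throughout these new kernel tests is assumed here to be a thin wrapper around `torch.library.opcheck`, which verifies an op's schema, fake/meta implementation, and dispatcher registrations against its eager behavior. A minimal sketch of the underlying check on a toy op (names are hypothetical):

```python
import torch

@torch.library.custom_op("demo::row_sum", mutates_args=())
def row_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)

@row_sum.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Describe only the output shape/dtype; no real computation happens here.
    return x.new_empty(x.shape[:-1])

# Runs the schema, fake-tensor, and autograd registration checks.
torch.library.opcheck(row_sum, (torch.randn(4, 8),))
```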
9 changes: 6 additions & 3 deletions tests/kernels/test_attention.py
@@ -205,7 +205,8 @@ def test_paged_attention(
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

elif version in ("v2", "rocm"):
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
@@ -246,7 +247,8 @@ def test_paged_attention(
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

else:
ops.paged_attention_rocm(
@@ -274,7 +276,8 @@ def test_paged_attention(
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))

else:
raise AssertionError(f"Unknown version: {version}")
38 changes: 38 additions & 0 deletions tests/kernels/test_awq.py
@@ -0,0 +1,38 @@
import os

import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401


def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
split_k_iters = 0
thx = 0
thy = 0
opcheck(torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy))


def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.randint(-2000000000,
2000000000, (64, 256),
device='cuda',
dtype=torch.int32)
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
split_k_iters = 8
opcheck(torch.ops._C.awq_gemm,
(input, qweight, qzeros, scales, split_k_iters))