VLLM Torch.compile Issue Tracker
Summary
This document tracks an existing issue with the way VLLM uses torch.compile
together with tensor subclasses.
TL;DR: VLLM doesn't set up AOTDispatch
correctly, so subclass flattening does not take place.
@zou3519 has, in theory, fixed this issue with vllm-project/vllm#17057, which enables standalone compile.
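For context, the flattening at stake looks like the sketch below: AOTDispatch is supposed to decompose wrapper tensor subclasses into their plain-tensor constituents (via __tensor_flatten__ / __tensor_unflatten__) before Inductor ever sees the graph. The class here is a made-up illustration, not VLLM's or torchao's actual quantized subclass.
import torch

class ScaledTensor(torch.Tensor):
    """Toy wrapper subclass holding raw data plus a scale tensor."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # AOTDispatch calls this to break the subclass into plain tensors.
        return ["_data", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return ScaledTensor(inner_tensors["_data"], inner_tensors["_scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Toy dispatch: unwrap positional subclass args, run the op on the
        # raw data, and rewrap with the first scale found (kwargs ignored).
        kwargs = kwargs or {}
        unwrapped = [a._data if isinstance(a, ScaledTensor) else a for a in args]
        out = func(*unwrapped, **kwargs)
        scale = next(a._scale for a in args if isinstance(a, ScaledTensor))
        return ScaledTensor(out, scale) if isinstance(out, torch.Tensor) else out
If the compile stack is wired up correctly, Inductor only ever sees the plain _data and _scale tensors; the issue above is that this desugaring step never happens.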
Testing Results
MXFP4 Model Testing
Command:
python vllm/sample_output.py --model_name "data/mxfp4-Qwen2-7B-Instruct" --compile True
Issue: The swizzle kernel needs to be enabled; without it, the dynamic control flow in to_blocked
errors out because of how VLLM bakes out different graphs.
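For illustration only, here is a hypothetical blocked-layout conversion with the kind of shape-dependent branch (padding when the scale matrix isn't a multiple of the block size) that interacts badly with how separate graphs get baked out. This is not the actual torchao/VLLM to_blocked implementation, and the 128x4 block shape is just an assumption for the sketch.
import torch
import torch.nn.functional as F

def to_blocked_sketch(scales: torch.Tensor, block_rows: int = 128, block_cols: int = 4) -> torch.Tensor:
    rows, cols = scales.shape
    padded_rows = ((rows + block_rows - 1) // block_rows) * block_rows
    padded_cols = ((cols + block_cols - 1) // block_cols) * block_cols
    # Shape-dependent control flow: under dynamic shapes this branch is the
    # kind of thing that forces specialization into different graphs.
    if padded_rows != rows or padded_cols != cols:
        scales = F.pad(scales, (0, padded_cols - cols, 0, padded_rows - rows))
    # Rearrange into (row-block, col-block) tiles, i.e. a "swizzled" layout.
    return (
        scales.reshape(padded_rows // block_rows, block_rows,
                       padded_cols // block_cols, block_cols)
        .permute(0, 2, 1, 3)
        .reshape(-1, block_rows * block_cols)
    )
A dedicated swizzle kernel sidesteps this by doing the rearrangement inside one opaque kernel call instead of tracing the branch.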
Error encountered:
torch._inductor.exc.InductorError: RuntimeError: Failed to import /tmp/torchinductor_drisspg/zw/czwhsx3gfghqtfk5wbhit2p2xpgr3yxil54pbsi2mtnirmrnrths.py
IndentationError: unexpected indent (czwhsx3gfghqtfk5wbhit2p2xpgr3yxil54pbsi2mtnirmrnrths.py, line 173)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Note: I suspect that Inductor isn't inlining the user-defined Triton kernel correctly.
cc @oulgen if you have any ideas here
FP8 Model Testing
Command:
python vllm/sample_output.py --model_name "data/fp8-Qwen2-7B-Instruct" --compile True
Results: ✅ Working
First run compilation time:
INFO 05-21 22:29:11 [monitor.py:33] torch.compile takes 136.26 s in total
Sample outputs:
- Prompt: 'Why is Pytorch 2.0 the best machine learning compiler?'
  Generated: ' PyTorch 2.0, currently not released officially, is anticipated to be a significant upgrade in several areas including performance, features, and usability...'
- Prompt: 'Hello, my name is'
  Generated: ' Mandy Lowry, and I am a certified financial planner and a long-time friend of the Girls' Town...'
- Prompt: 'The president of the United States is'
  Generated: " the leader of the government of the United States. The president is also the commander-in-chief of the United States Armed Forces..."
- Prompt: 'The capital of France is'
  Generated: '__.\nEdinburgh\nGeneva\nParis\nLondon\n答案:\n\nC...'
- Prompt: 'The future of AI is'
  Generated: ' moving closer to reality with the launch of a revolutionary new AI-powered software platform called Hummingbird...'
Second run compilation time (cached):
INFO 05-21 22:30:32 [backends.py:134] Directly load the compiled graph(s) for shape None from the cache, took 8.395 s
INFO 05-21 22:30:42 [monitor.py:33] torch.compile takes 7.26 s in total
Test Script
import os
import random
import numpy as np
import torch
from vllm import LLM, SamplingParams
from rich import print


def set_seed(seed):
    """Set seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main(
    model_name: str = "Qwen/Qwen2-7B-Instruct",
    max_tokens=64,
    tp_size: int = 1,
    compile: bool = True,
):
    # Set seed before creating the LLM
    set_seed(42)

    # Environment variables for VLLM configuration
    # os.environ["VLLM_TORCH_PROFILER_DIR"] = "data/flex_profile"  # Enable torch profiler
    os.environ["VLLM_USE_V1"] = "1"
    # os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION_VLLM_V1"
    # os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    # os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
    os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    # Create sampling params
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        seed=42,
        max_tokens=max_tokens,
    )

    # Create LLM instance
    print(f"Using Model name: {model_name}")
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enforce_eager=not compile,
    )

    # Test prompts
    prompts = [
        "Why is Pytorch 2.0 the best machine learning compiler?",
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate outputs
    outputs = llm.generate(prompts, sampling_params)

    # Print results
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
Activity
zou3519 commented on Jun 2, 2025
@drisspg can you share the contents of the file?
drisspg commented on Jun 3, 2025
Here is the Inductor file: https://www.internalfb.com/intern/paste/P1830942194/
TLP: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/drisspg/custom-8c7f0c39/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
drisspg commented on Jun 3, 2025
Okay, so I sat next to Richard and actually looked at this. It's kind of sad, but the actual problem is just that I have a triple-quoted doc block for the kernel, and this messes up the outer quotation in the generated file.
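To make the failure mode concrete, here is a minimal illustration of the pattern described above. The kernel is hypothetical (a trivial copy kernel standing in for the real MXFP4 swizzle kernel); the point is only the triple-quoted docstring inside a user-defined Triton kernel, whose quotes can collide with the quoting Inductor uses when it embeds the kernel source into its generated /tmp/torchinductor_*/... file, which is consistent with the IndentationError reported earlier.
import triton
import triton.language as tl

@triton.jit
def swizzle_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy kernel standing in for the real swizzle.

    The docstring itself is the point: its triple quotes are what get
    mangled when the kernel source is pasted into Inductor's output.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)
Dropping the docstring (or moving it outside the kernel body) avoids the nested-quote problem.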