
Tensor Subclass + VLLM Compile #2239

Description

drisspg (Contributor)

VLLM Torch.compile Issue Tracker

Summary

This document tracks the existing issue with the way vLLM uses torch.compile together with tensor subclasses.

TL;DR: vLLM doesn't set up aotdispatch correctly, so tensor subclass flattening never takes place.

@zou3519 has theoretically fixed this issue with vllm-project/vllm#17057, which enables standalone compile.
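
For reference, a minimal sketch of the wrapper-subclass protocol involved (the class and its fields below are illustrative, not vLLM's or torchao's actual code): when aotdispatch is wired up correctly, torch.compile uses __tensor_flatten__ to desugar the subclass into its plain inner tensors before tracing, and __tensor_unflatten__ to rebuild it afterwards. When it isn't, the subclass reaches the compiler unflattened.

import torch
import torch.utils._pytree as pytree


class ScaledTensor(torch.Tensor):
    """Illustrative wrapper subclass: a quantized payload plus a scale."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # Names of the inner tensor attributes, plus opaque metadata.
        return ["_data", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return ScaledTensor(inner_tensors["_data"], inner_tensors["_scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the plain inner tensors and run the op on those.
        def unwrap(t):
            return t._data if isinstance(t, ScaledTensor) else t

        return func(
            *pytree.tree_map(unwrap, args),
            **pytree.tree_map(unwrap, kwargs or {}),
        )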

Testing Results

MXFP4 Model Testing

Command:

python vllm/sample_output.py --model_name "data/mxfp4-Qwen2-7B-Instruct" --compile True

Issue: The swizzle kernel needs to be enabled; without it, the dynamic control flow in to_blocked errors against the way vLLM bakes out separate graphs (see the sketch below).
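
Paraphrased sketch of the kind of shape-dependent branch in to_blocked that causes the problem (the function below is a hypothetical stand-in, not the exact torchao source): padding only fires for some input shapes, so the graphs vLLM specializes per shape can diverge here.

import torch
import torch.nn.functional as F


def to_blocked_sketch(scales: torch.Tensor) -> torch.Tensor:
    rows, cols = scales.shape
    padded_rows = ((rows + 127) // 128) * 128
    padded_cols = ((cols + 3) // 4) * 4
    # Shape-dependent control flow: only some shapes take the padding path.
    if (rows, cols) != (padded_rows, padded_cols):
        scales = F.pad(scales, (0, padded_cols - cols, 0, padded_rows - rows))
    # Rearrange into the 128x4 blocked layout the swizzle kernel expects.
    return (
        scales.reshape(padded_rows // 128, 128, padded_cols // 4, 4)
        .permute(0, 2, 1, 3)
        .reshape(-1, 32, 16)
    )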

Error encountered:

torch._inductor.exc.InductorError: RuntimeError: Failed to import /tmp/torchinductor_drisspg/zw/czwhsx3gfghqtfk5wbhit2p2xpgr3yxil54pbsi2mtnirmrnrths.py
IndentationError: unexpected indent (czwhsx3gfghqtfk5wbhit2p2xpgr3yxil54pbsi2mtnirmrnrths.py, line 173)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Note: I suspect it isn't inlining the user-defined Triton kernel correctly.
cc @oulgen if you have any ideas here

FP8 Model Testing

Command:

python vllm/sample_output.py --model_name "data/fp8-Qwen2-7B-Instruct" --compile True

Results: ✅ Working

First run compilation time:

INFO 05-21 22:29:11 [monitor.py:33] torch.compile takes 136.26 s in total

Sample outputs:

  • Prompt: 'Why is Pytorch 2.0 the best machine learning compiler?'
    Generated: ' PyTorch 2.0, currently not released officially, is anticipated to be a significant upgrade in several areas including performance, features, and usability...'

  • Prompt: 'Hello, my name is'
    Generated: ' Mandy Lowry, and I am a certified financial planner and a long-time friend of the Girls' Town...'

  • Prompt: 'The president of the United States is'
    Generated: " the leader of the government of the United States. The president is also the commander-in-chief of the United States Armed Forces..."

  • Prompt: 'The capital of France is'
    Generated: '__.\nEdinburgh\nGeneva\nParis\nLondon\n答案:\n\nC...'

  • Prompt: 'The future of AI is'
    Generated: ' moving closer to reality with the launch of a revolutionary new AI-powered software platform called Hummingbird...'

Second run compilation time (cached):

INFO 05-21 22:30:32 [backends.py:134] Directly load the compiled graph(s) for shape None from the cache, took 8.395 s
INFO 05-21 22:30:42 [monitor.py:33] torch.compile takes 7.26 s in total

Test Script

import os
import random
import numpy as np
import torch

from vllm import LLM, SamplingParams
from rich import print


def set_seed(seed):
    """Set seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main(
    model_name: str = "Qwen/Qwen2-7B-Instruct",
    max_tokens: int = 64,
    tp_size: int = 1,
    compile: bool = True,
):
    # Set seed before creating the LLM
    set_seed(42)

    # Environment variables for VLLM configuration
    # os.environ["VLLM_TORCH_PROFILER_DIR"] = "data/flex_profile"  # Enable torch profiler
    os.environ["VLLM_USE_V1"] = "1"
    # os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION_VLLM_V1"
    # os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    # os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
    os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    # Create sampling params
    sampling_params = SamplingParams(
        temperature=0.8, 
        top_p=0.95, 
        seed=42, 
        max_tokens=max_tokens
    )
    
    # Create LLM instance
    print(f"Using Model name: {model_name}")
    llm = LLM(
        model=model_name, 
        tensor_parallel_size=tp_size, 
        enforce_eager=not compile
    )
    
    # Test prompts
    prompts = [
        "Why is Pytorch 2.0 the best machine learning compiler?",
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    
    # Generate outputs
    outputs = llm.generate(prompts, sampling_params)

    # Print results
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    from jsonargparse import CLI
    CLI(main)

Activity

zou3519 (Contributor) commented on Jun 2, 2025

@drisspg can you share the contents of the file?

drisspg (Contributor, Author) commented on Jun 3, 2025

Okay, so I sat next to Richard and actually looked at this. It's kinda sad, but the actual problem is just that I have a triple-quoted doc block for the kernel, and this messes up some external quotations.
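
For illustration, a minimal hypothetical kernel showing this failure mode (not the actual swizzle kernel): inductor embeds the user-defined Triton kernel's source as text inside its generated .py file, so a triple-quoted docstring in the kernel can terminate the enclosing quotes early and produce the IndentationError seen above.

import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    """A triple-quoted doc block like this one is what collides with the
    quoting inductor uses when it serializes the kernel source."""
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)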
