Commit f9bc5a0

[Bugfix] Fix triton import with local TritonPlaceholder (#17446)
Signed-off-by: Mengqing Cao <[email protected]>
1 parent 05e1f96 commit f9bc5a0

30 files changed (+171, -81 lines)
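
For context, the fix routes every triton import through vllm.triton_utils, which re-exports the real Triton modules when they are installed and local placeholders when they are not. A minimal sketch of what that fallback wiring could look like (an assumption for illustration; the actual vllm/triton_utils/__init__.py may be organised differently):

# Hypothetical sketch of the fallback wiring, not the actual vLLM source.
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    # Triton is unavailable (e.g. CPU-only platforms), so hand out
    # module-like placeholders that keep kernel files importable.
    from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
                                             TritonPlaceholder)

    triton = TritonPlaceholder()
    tl = TritonLanguagePlaceholder()
    HAS_TRITON = False

__all__ = ["HAS_TRITON", "tl", "triton"]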

benchmarks/kernels/benchmark_moe.py

Lines changed: 1 addition & 1 deletion
@@ -10,12 +10,12 @@
 
 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
 
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser
 
 FP8_DTYPE = current_platform.fp8_dtype()

benchmarks/kernels/benchmark_rmsnorm.py

Lines changed: 1 addition & 1 deletion
@@ -4,11 +4,11 @@
 from typing import Optional, Union
 
 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn
 
 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton
 
 
 class HuggingFaceRMSNorm(nn.Module):

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 1 addition & 1 deletion
@@ -6,13 +6,13 @@
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
 
 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton
 
 
 # Copied from

tests/kernels/attention/test_flashmla.py

Lines changed: 1 addition & 1 deletion
@@ -5,11 +5,11 @@
 
 import pytest
 import torch
-import triton
 
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton
 
 
 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:

tests/test_triton_utils.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+import types
+from unittest import mock
+
+from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
+
+
+def test_triton_placeholder_is_module():
+    triton = TritonPlaceholder()
+    assert isinstance(triton, types.ModuleType)
+    assert triton.__name__ == "triton"
+
+
+def test_triton_language_placeholder_is_module():
+    triton_language = TritonLanguagePlaceholder()
+    assert isinstance(triton_language, types.ModuleType)
+    assert triton_language.__name__ == "triton.language"
+
+
+def test_triton_placeholder_decorators():
+    triton = TritonPlaceholder()
+
+    @triton.jit
+    def foo(x):
+        return x
+
+    @triton.autotune
+    def bar(x):
+        return x
+
+    @triton.heuristics
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_decorators_with_args():
+    triton = TritonPlaceholder()
+
+    @triton.jit(debug=True)
+    def foo(x):
+        return x
+
+    @triton.autotune(configs=[], key="x")
+    def bar(x):
+        return x
+
+    @triton.heuristics(
+        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_language():
+    lang = TritonLanguagePlaceholder()
+    assert isinstance(lang, types.ModuleType)
+    assert lang.__name__ == "triton.language"
+    assert lang.constexpr is None
+    assert lang.dtype is None
+    assert lang.int64 is None
+
+
+def test_triton_placeholder_language_from_parent():
+    triton = TritonPlaceholder()
+    lang = triton.language
+    assert isinstance(lang, TritonLanguagePlaceholder)
+
+
+def test_no_triton_fallback():
+    # clear existing triton modules
+    sys.modules.pop("triton", None)
+    sys.modules.pop("triton.language", None)
+    sys.modules.pop("vllm.triton_utils", None)
+    sys.modules.pop("vllm.triton_utils.importing", None)
+
+    # mock triton not being installed
+    with mock.patch.dict(sys.modules, {"triton": None}):
+        from vllm.triton_utils import HAS_TRITON, tl, triton
+        assert HAS_TRITON is False
+        assert triton.__class__.__name__ == "TritonPlaceholder"
+        assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
+        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
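
The new tests pin down the placeholder contract: both placeholders behave like modules, the jit, autotune and heuristics decorators work with and without arguments and return the wrapped function unchanged, and triton.language attributes such as constexpr, dtype and int64 resolve to None. A minimal sketch consistent with that contract (an assumption for illustration; the real classes in vllm/triton_utils/importing.py may differ in detail):

import types


class TritonLanguagePlaceholder(types.ModuleType):
    # Stand-in for triton.language; type attributes resolve to None.

    def __init__(self):
        super().__init__("triton.language")
        self.constexpr = None
        self.dtype = None
        self.int64 = None


class TritonPlaceholder(types.ModuleType):
    # Stand-in for the triton module when Triton is not installed.

    def __init__(self):
        super().__init__("triton")
        self.jit = self._dummy_decorator()
        self.autotune = self._dummy_decorator()
        self.heuristics = self._dummy_decorator()
        self.language = TritonLanguagePlaceholder()

    def _dummy_decorator(self):
        # Support both bare use (@triton.jit) and parameterized use
        # (@triton.jit(debug=True)): return the function untouched either way.
        def decorator(*args, **kwargs):
            if args and callable(args[0]):
                return args[0]
            return lambda f: f

        return decorator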

vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 
 def blocksparse_flash_attn_varlen_fwd(
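
The remaining kernel files below repeat this one-line swap. Its practical effect is that a Triton kernel module stays importable on hosts without Triton and only fails if a kernel is actually launched there. A hypothetical module, for illustration only, showing the pattern these files now follow:

# illustrative_copy_kernel.py -- hypothetical example, not part of this commit
from vllm.triton_utils import tl, triton


@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # With real Triton this compiles lazily at launch time; with the
    # placeholder, @triton.jit is a no-op and the body is never executed.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)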

vllm/attention/ops/blocksparse_attention/utils.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
 
 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton
 
 
 class csr_matrix:

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 1 addition & 2 deletions
@@ -7,11 +7,10 @@
 # - Thomas Parnell <[email protected]>
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton
 
 from .prefix_prefill import context_attention_fwd
 
vllm/attention/ops/prefix_prefill.py

Lines changed: 1 addition & 2 deletions
@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 
 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

vllm/attention/ops/triton_decode_attention.py

Lines changed: 1 addition & 3 deletions
@@ -30,10 +30,8 @@
 
 import logging
 
-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 
 is_hip_ = current_platform.is_rocm()
 