
Commit 355fa4b

[Bugfix] Fix triton import with local TritonPlaceholder
Signed-off-by: Mengqing Cao <[email protected]>
1 parent c777df7 commit 355fa4b

29 files changed (+79, -81 lines)
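
All of the diffs below follow the same pattern: direct "import triton" / "import triton.language as tl" statements are replaced with re-exports from vllm.triton_utils, which, per the commit title, falls back to a local TritonPlaceholder when Triton is unavailable. As a rough sketch of the idea only (the placeholder class below is an illustrative assumption, not the actual vllm.triton_utils code), such a guarded re-export could look like this:

# Hypothetical sketch of a guarded re-export; not the actual
# vllm.triton_utils implementation.
try:
    import triton
    import triton.language as tl
except ImportError:

    class _Placeholder:
        """Minimal stand-in so modules that only define Triton kernels
        can still be imported on systems without Triton installed."""

        def __init__(self, name: str):
            self._name = name

        def __getattr__(self, attr: str):
            # Attribute chains such as triton.Config(...) or tl.constexpr
            # resolve to further placeholders at import time.
            return _Placeholder(f"{self._name}.{attr}")

        def __call__(self, *args, **kwargs):
            # Lets decorators like @triton.jit or @triton.autotune(...)
            # wrap a function without doing anything.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            return self

    triton = _Placeholder("triton")
    tl = _Placeholder("tl")

With a re-export along these lines, each call site only has to change its import line, which is exactly what the 29 files in this commit do.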

benchmarks/kernels/benchmark_moe.py

Lines changed: 1 addition & 1 deletion
@@ -10,12 +10,12 @@

 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 FP8_DTYPE = current_platform.fp8_dtype()

benchmarks/kernels/benchmark_rmsnorm.py

Lines changed: 1 addition & 1 deletion
@@ -4,11 +4,11 @@
 from typing import Optional, Union

 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn

 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton


 class HuggingFaceRMSNorm(nn.Module):

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 1 addition & 1 deletion
@@ -6,13 +6,13 @@
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton


 # Copied from

tests/kernels/attention/test_flashmla.py

Lines changed: 1 addition & 1 deletion
@@ -5,11 +5,11 @@

 import pytest
 import torch
-import triton

 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton


 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:

vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def blocksparse_flash_attn_varlen_fwd(

vllm/attention/ops/blocksparse_attention/utils.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@

 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton


 class csr_matrix:

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 1 addition & 2 deletions
@@ -7,11 +7,10 @@
 # - Thomas Parnell <[email protected]>

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd

vllm/attention/ops/prefix_prefill.py

Lines changed: 1 addition & 2 deletions
@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

vllm/attention/ops/triton_decode_attention.py

Lines changed: 1 addition & 3 deletions
@@ -30,10 +30,8 @@

 import logging

-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 is_hip_ = current_platform.is_rocm()

vllm/attention/ops/triton_flash_attention.py

Lines changed: 1 addition & 2 deletions
@@ -25,11 +25,10 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

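
After this commit, kernel modules reach Triton through a single import path. A minimal illustrative module written against the re-export (the add kernel itself is an assumption for demonstration, not code from this commit):

from vllm.triton_utils import tl, triton


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the tensors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

Because the module never imports the triton package directly, merely importing it should no longer raise ImportError on platforms without Triton, which appears to be the failure mode the commit title refers to.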

0 commit comments
