 import torch

 from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
-                                        native_w8a8_block_matmul,
-                                        per_block_cast_to_fp8)
+                                        native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
-
-dg_available = False
-try:
-    import deep_gemm
-    dg_available = True
-except ImportError:
-    pass
+from vllm.utils import has_deep_gemm
+from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
+                                   per_token_group_cast_to_fp8)

 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
 @pytest.mark.parametrize(
     "M,N,K,block_size,out_dtype,seed",
     itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
-@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
+@pytest.mark.skipif(not has_deep_gemm(),
+                    reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     # only aligned sizes
@@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max

-    _, block_k = block_size[0], block_size[1]
-
-    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
+    A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)

     As = As_fp8.to(torch.float32)
@@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
                                      out_dtype)

     # Transpose earlier so that the testing will not trigger transposing kernels
-    As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
+    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)

     out = torch.zeros((M, N), device='cuda', dtype=out_dtype)

     assert As_fp8.shape == (M, (K + 127) //
                             128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"

-    deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
+    fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)

     rel_diff = (torch.mean(
         torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
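
For context, here is a minimal standalone sketch of the flow the updated test exercises, using only the names this diff imports (has_deep_gemm, per_token_group_cast_to_fp8, per_block_cast_to_fp8, get_col_major_tma_aligned_tensor, fp8_gemm_nt). The shapes, explicit CUDA placement, and the loose sanity check against a plain fp32 matmul are illustrative assumptions rather than part of the test; a Hopper-class GPU with DeepGEMM installed is presumed.

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor)
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
                                  per_token_group_cast_to_fp8)

assert has_deep_gemm(), "DeepGemm kernels not available."

# Hypothetical sizes, chosen to satisfy DeepGEMM's alignment (K and N are
# multiples of the 128-wide quantization blocks used by the test).
M, N, K = 128, 512, 512
block_k = 128

A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2
B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2

# Activations are quantized per token group, weights per 128x128 block.
A_fp8, As = per_token_group_cast_to_fp8(A_fp32, block_k)
B_fp8, Bs = per_block_cast_to_fp8(B_fp32)

# DeepGEMM expects activation scales in a column-major, TMA-aligned layout.
As = get_col_major_tma_aligned_tensor(As)

out = torch.zeros((M, N), device="cuda", dtype=torch.bfloat16)
fp8_gemm_nt((A_fp8, As), (B_fp8, Bs), out)

# Loose sanity check against the unquantized product (fp8 rounding makes an
# exact comparison meaningless); the real test compares against the
# block-quantized native reference instead.
ref = A_fp32 @ B_fp32.t()
rel_diff = (out.float() - ref).abs().mean() / ref.abs().mean()
print(f"mean relative difference vs fp32 reference: {rel_diff.item():.4f}")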