 from torch.distributed._tensor import DTensor

 from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
 )
-from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import (
+    mxfp8_quantize_cuda,
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )


-def _triton_to_mxfp8_dim1_wrapper(
-    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+def _to_mxfp8_dim1_kernel_wrapper(
+    a,
+    block_size,
+    elem_dtype,
+    hp_dtype,
+    gemm_kernel_choice,
+    cast_kernel_choice,
 ):
-    a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        _, a_data, _, a_scale = mxfp8_quantize_cuda(
+            a,
+            rowwise=False,
+            colwise=True,
+            scaling_mode="floor",
+        )
+    else:
+        raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}")
+
     if isinstance(a_data, DTensor):
         assert isinstance(a_scale, DTensor)
         a_data_local = a_data.to_local()
@@ -86,15 +106,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice

         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -119,7 +139,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice

         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -135,9 +155,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )

-        if use_fp8_dim1_cast_triton_kernel:
-            weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
-                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            weight_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
+                weight_hp,
+                block_size,
+                w_elem_dtype,
+                weight_hp.dtype,
+                gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
@@ -153,13 +178,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         )

         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
-            grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            grad_output_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
                 grad_output_hp_r,
                 block_size,
                 grad_elem_dtype,
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
@@ -169,13 +195,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )

-        if use_fp8_dim1_cast_triton_kernel:
-            input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            input_t_mx_dim0_tmp = _to_mxfp8_dim1_kernel_wrapper(
                 input_hp_r,
                 block_size,
                 in_elem_dtype,
                 input_hp_r.dtype,
                 gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
@@ -232,7 +259,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
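
A minimal usage sketch of the new option, assuming the config field name shown in the last hunk (mxfp8_cast_kernel_choice on MXLinearConfig) and the existing torchao quantize_ flow; the module sizes and block_size value here are illustrative, not taken from this PR.

import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    MXLinearConfig,
)

# bf16 model on GPU; MX training swaps nn.Linear for MXLinear via quantize_
m = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
m = m.cuda().to(torch.bfloat16)

# Select the CUDA kernel for the dim1 (column-wise) MXFP8 cast used in the
# backward pass; TRITON and TORCH are the other enum members in this diff.
config = MXLinearConfig(
    block_size=32,
    mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
)
quantize_(m, config=config)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = m(x)
y.sum().backward()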