@@ -122,7 +122,7 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
@@ -136,18 +136,18 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_fp8_row_major), (
+        assert not _is_column_major(A_data_row_major), (
             "A must be row-major for output = A @ B"
         )
-        assert _is_column_major(B_t_fp8_col_major), (
+        assert _is_column_major(B_t_data_col_major), (
             "B must be column-major for output = A @ B"
         )
 
@@ -157,8 +157,8 @@ def forward(
         A_scales = A_scales.squeeze(-1)
         B_t_scales = B_t_scales.squeeze(1)
         return torch._scaled_grouped_mm(
-            A_fp8_row_major,
-            B_t_fp8_col_major,
+            A_data_row_major,
+            B_t_data_col_major,
             A_scales.reciprocal(),  # Reciprocals are needed for rescaling the output.
             B_t_scales.reciprocal(),
             offs,
@@ -184,13 +184,13 @@ def backward(ctx, grad_output: torch.Tensor):
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_fp8_row_major = to_fp8_saturated(
+        grad_output_data_row_major = to_fp8_saturated(
             grad_output_scaled, torch.float8_e4m3fn
         )
 
         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
-        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
@@ -199,10 +199,10 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_fp8_row_major), (
+        assert not _is_column_major(grad_output_data_row_major), (
             "grad_output must be row-major for grad_A = grad_output @ B"
         )
-        assert _is_column_major(B_fp8_col_major), (
+        assert _is_column_major(B_data_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
 
@@ -212,8 +212,8 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_output_scales = grad_output_scales.squeeze(-1)
         B_scales = B_scales.squeeze(1)
         grad_A = torch._scaled_grouped_mm(
-            grad_output_fp8_row_major,
-            B_fp8_col_major,
+            grad_output_data_row_major,
+            B_data_col_major,
             grad_output_scales.reciprocal(),
             B_scales.reciprocal(),
             offs,
@@ -227,18 +227,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         # Use transpose method to avoid uncoalesced memory accesses.
-        grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
+        grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
             grad_output.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
+        grad_output_t_data_row_major = grad_out_data_colwise.t()
         grad_output_t_scales = grad_out_scales.t()
 
-        A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
+        A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
             A.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
@@ -249,19 +249,19 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_fp8_row_major), (
+        assert not _is_column_major(grad_output_t_data_row_major), (
             "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
-        assert _is_column_major(A_fp8_col_major), (
+        assert _is_column_major(A_data_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
 
         # Per-token group scales computed via triton kernels above do not have
         # the empty dim like the scales computed via tensor_to_scale, so we need
         # don't need to squeeze here.
         grad_B = torch._scaled_grouped_mm(
-            grad_output_t_fp8_row_major,
-            A_fp8_col_major,
+            grad_output_t_data_row_major,
+            A_data_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
             offs,
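
The layout asserts renamed in the hunks above encode a requirement of torch._scaled_grouped_mm: the left operand must be row-major and the right operand column-major. The _is_column_major helper is defined elsewhere in this file and is not part of this diff; the snippet below is only a minimal stride-based sketch of such a check, with hypothetical names, not the file's actual implementation.

import torch

def is_column_major_sketch(x: torch.Tensor) -> bool:
    # A tensor is column-major in its last two dims when elements along
    # dim -2 are contiguous, i.e. stride(-2) == 1.
    assert x.ndim in (2, 3), "expected a 2D or batched 3D tensor"
    return x.stride(-2) == 1 and x.stride(-1) > 1

a = torch.randn(64, 32)           # row-major: strides (32, 1)
b = torch.randn(32, 64).t()       # column-major view: strides (1, 32)
print(is_column_major_sketch(a))  # False
print(is_column_major_sketch(b))  # True
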
@@ -295,13 +295,15 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
 
-        # A_mx shape: (M, K)
+        # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_data = to_mx(
+            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
 
-        # B_mx shape: (E, N, K)
+        # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -315,9 +317,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_data,
             A_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             block_size=block_size,
@@ -332,15 +334,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated
 
-        # grad_out_mx shape: (M, N)
+        # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_data = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )
 
-        # B_mx shape: (E, K, N)
+        # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -354,43 +356,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_data,
             grad_out_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_data shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_data = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_data shape = (M, K)
+        A_data = A_t_data.transpose(-2, -1)
 
         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)
 
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_data,
             grad_out_t_scales,
-            A_mx,
+            A_data,
             A_scales,
             offs=offs,
         )
@@ -402,64 +404,68 @@ def backward(ctx, grad_out: torch.Tensor):
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_data: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_data: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}"
+    assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}"
+    assert A_scale.shape[0] == A_data.shape[0], (
+        f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_data.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_data.shape[0], (
+        f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_data.shape[1], (
+        f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
    )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_data.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
    )
 
     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_data shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_data shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_data = A_data.reshape(
+        *A_data.shape[:-1], A_data.shape[-1] // block_size, block_size
+    )
     A_scale = A_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_data.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)
 
     # Dequantize weights
     # Tranpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_data shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_data shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_data = B_data.reshape(
+        *B_data.shape[:-1], B_data.shape[-1] // block_size, block_size
+    )
     B_scale = B_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_data.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # B shape: (E, K, N)
@@ -471,27 +477,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_data: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_data: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_data.ndim == 2, "A must be 2D"
+    assert B_data.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_data.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_data.device,
+        requires_grad=A_data.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_data.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_data.device,
+        requires_grad=B_data.requires_grad,
     )
 
     # Dequantize input per each scaling group
@@ -507,7 +513,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
        # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_data[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape
 
         # Get scales for this group.
@@ -532,7 +538,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_data[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape
 
         # Scales shape is (group_size//block_size, N)
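
The emulated MXFP8 paths touched by this rename all dequantize the same way: reshape the data into blocks of block_size along the K dimension, multiply each block by its per-block scale, and reshape back. A minimal self-contained sketch of that pattern, assuming an ordinary float scale tensor rather than the E8M0 scales produced by to_mx:

import torch

def dequantize_blockwise_sketch(
    data: torch.Tensor,   # (M, K), low-precision values (e.g. a float8 cast)
    scale: torch.Tensor,  # (M, K // block_size), one scale per block
    block_size: int = 32,
) -> torch.Tensor:
    M, K = data.shape
    assert K % block_size == 0, "K must be divisible by block_size"
    # (M, K) -> (M, K // block_size, block_size) so each block lines up
    # with its scale after unsqueezing the scale to (M, K // block_size, 1).
    blocks = data.reshape(M, K // block_size, block_size)
    out = blocks.to(torch.bfloat16) * scale.unsqueeze(-1).to(torch.bfloat16)
    # Reshape back to the original (M, K) layout.
    return out.reshape(M, K)

# Tiny usage example with made-up data.
data = torch.randn(4, 64).to(torch.float8_e4m3fn)
scale = torch.full((4, 2), 2.0)
deq = dequantize_blockwise_sketch(data, scale, block_size=32)
print(deq.shape)  # torch.Size([4, 64])
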