Commit fb1628c

[mxfp8 moe training] refactor all var names with suffix _mx to _data for clarity
stack-info: PR: #2879, branch: danielvegamyhre/stack/60
1 parent 15a6de6 commit fb1628c
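
Note on the rename: to_mx returns a pair of (block scales, raw fp8 element tensor), so the "_data" suffix labels the plain fp8 payload rather than suggesting a wrapped MX tensor type. A minimal sketch of that convention, assuming to_mx is imported from torchao.prototype.mx_formats.mx_tensor (import path assumed, not shown in this diff):

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

A = torch.randn(128, 256, dtype=torch.bfloat16)
block_size = 32

# to_mx returns (scales, data): "data" is the raw fp8 element tensor,
# with one scale per block_size-wide block along the last dim.
A_scale, A_data = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
# A_data shape: (128, 256); A_scale shape: (128, 256 // block_size)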

1 file changed: +81 -75 lines changed


torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 81 additions & 75 deletions
@@ -122,7 +122,7 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
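
As context for the hunk above: the rowwise fp8 path rescales in float32 and then does a saturated cast into float8. A minimal sketch of what a to_fp8_saturated-style helper does (an illustrative stand-in, not necessarily torchao's exact implementation):

import torch

def to_fp8_saturated_sketch(x: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Clamp into the representable range first so out-of-range values saturate
    # to +/- max instead of becoming inf/nan during the cast.
    max_val = torch.finfo(fp8_dtype).max
    return x.clamp(min=-max_val, max=max_val).to(fp8_dtype)

A_scaled = torch.randn(16, 64) * 1000.0
A_data_row_major = to_fp8_saturated_sketch(A_scaled, torch.float8_e4m3fn)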
@@ -136,18 +136,18 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_fp8_row_major), (
+        assert not _is_column_major(A_data_row_major), (
             "A must be row-major for output = A @ B"
         )
-        assert _is_column_major(B_t_fp8_col_major), (
+        assert _is_column_major(B_t_data_col_major), (
             "B must be column-major for output = A @ B"
         )
 
@@ -157,8 +157,8 @@ def forward(
         A_scales = A_scales.squeeze(-1)
         B_t_scales = B_t_scales.squeeze(1)
         return torch._scaled_grouped_mm(
-            A_fp8_row_major,
-            B_t_fp8_col_major,
+            A_data_row_major,
+            B_t_data_col_major,
             A_scales.reciprocal(),  # Reciprocals are needed for rescaling the output.
             B_t_scales.reciprocal(),
             offs,
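
The layout asserts above check strides, not shapes. A hedged sketch of a stride-based check in the spirit of _is_column_major (illustrative only; the repo's helper may differ):

import torch

def is_column_major_sketch(x: torch.Tensor) -> bool:
    # Column-major over the last two dims: elements within a column are
    # contiguous, i.e. the stride of dim -2 is 1.
    return x.stride(-2) == 1

a = torch.randn(16, 32)
print(is_column_major_sketch(a))      # False: freshly allocated tensors are row-major
print(is_column_major_sketch(a.t()))  # True: the transpose view is column-major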
@@ -184,13 +184,13 @@ def backward(ctx, grad_output: torch.Tensor):
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_fp8_row_major = to_fp8_saturated(
+        grad_output_data_row_major = to_fp8_saturated(
             grad_output_scaled, torch.float8_e4m3fn
         )
 
         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
-        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
@@ -199,10 +199,10 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_fp8_row_major), (
+        assert not _is_column_major(grad_output_data_row_major), (
             "grad_output must be row-major for grad_A = grad_output @ B"
         )
-        assert _is_column_major(B_fp8_col_major), (
+        assert _is_column_major(B_data_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
 
@@ -212,8 +212,8 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_output_scales = grad_output_scales.squeeze(-1)
         B_scales = B_scales.squeeze(1)
         grad_A = torch._scaled_grouped_mm(
-            grad_output_fp8_row_major,
-            B_fp8_col_major,
+            grad_output_data_row_major,
+            B_data_col_major,
             grad_output_scales.reciprocal(),
             B_scales.reciprocal(),
             offs,
@@ -227,18 +227,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         # Use transpose method to avoid uncoalesced memory accesses.
-        grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
+        grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
             grad_output.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
+        grad_output_t_data_row_major = grad_out_data_colwise.t()
         grad_output_t_scales = grad_out_scales.t()
 
-        A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
+        A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
             A.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
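
The .t().contiguous().t() pattern in the hunk above changes only memory layout: it materializes the transpose contiguously, then views it back, so the values are unchanged but the result is column-major. A small sketch of the effect:

import torch

grad_output = torch.randn(64, 128)            # row-major, strides (128, 1)
col_major = grad_output.t().contiguous().t()  # same values, strides (1, 64)

assert torch.equal(col_major, grad_output)    # logical values unchanged
print(grad_output.stride(), col_major.stride())  # (128, 1) (1, 64)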
@@ -249,19 +249,19 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_fp8_row_major), (
+        assert not _is_column_major(grad_output_t_data_row_major), (
             "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
-        assert _is_column_major(A_fp8_col_major), (
+        assert _is_column_major(A_data_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
 
         # Per-token group scales computed via triton kernels above do not have
         # the empty dim like the scales computed via tensor_to_scale, so we need
         # don't need to squeeze here.
         grad_B = torch._scaled_grouped_mm(
-            grad_output_t_fp8_row_major,
-            A_fp8_col_major,
+            grad_output_t_data_row_major,
+            A_data_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
             offs,
@@ -295,13 +295,15 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
 
-        # A_mx shape: (M, K)
+        # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_data = to_mx(
+            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
 
-        # B_mx shape: (E, N, K)
+        # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -315,9 +317,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_data,
             A_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             block_size=block_size,
@@ -332,15 +334,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated
 
-        # grad_out_mx shape: (M, N)
+        # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_data = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )
 
-        # B_mx shape: (E, K, N)
+        # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -354,43 +356,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_data,
             grad_out_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_data shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_data = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_data shape = (M, K)
+        A_data = A_t_data.transpose(-2, -1)
 
         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)
 
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_data,
             grad_out_t_scales,
-            A_mx,
+            A_data,
             A_scales,
             offs=offs,
         )
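
The transpose/un-transpose step above exists because to_mx produces block scales along the last dimension; to get scales that run along M for the grad_B operand, A is quantized as (K, M) and the resulting data and scales are transposed back. A shape-only sketch, with the same assumed to_mx import as above:

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

M, K, block_size = 128, 256, 32
A = torch.randn(M, K, dtype=torch.bfloat16)

# Quantize the transpose so blocks of 32 elements run down the M dimension.
A_t_scales, A_t_data = to_mx(
    A.transpose(-2, -1).contiguous(),  # (K, M)
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
)
A_data = A_t_data.transpose(-2, -1)      # back to (M, K)
A_scales = A_t_scales.transpose(-2, -1)  # (M // block_size, K)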
@@ -402,64 +404,68 @@ def backward(ctx, grad_out: torch.Tensor):
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_data: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_data: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}"
+    assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}"
+    assert A_scale.shape[0] == A_data.shape[0], (
+        f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_data.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_data.shape[0], (
+        f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_data.shape[1], (
+        f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
    )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_data.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
     )
 
     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_data shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_data shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_data = A_data.reshape(
+        *A_data.shape[:-1], A_data.shape[-1] // block_size, block_size
+    )
     A_scale = A_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_data.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)
 
     # Dequantize weights
     # Tranpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_data shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_data shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_data = B_data.reshape(
+        *B_data.shape[:-1], B_data.shape[-1] // block_size, block_size
+    )
     B_scale = B_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_data.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # B shape: (E, K, N)
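
The emulated dequantization above broadcasts each block scale over its block: make the block dimension explicit via reshape, multiply by the unsqueezed scale, and reshape back. A self-contained sketch of that step with stand-in tensors:

import torch

M, K, block_size = 8, 64, 32
A_data = torch.randn(M, K).to(torch.float8_e4m3fn)  # stand-in fp8 payload
A_scale = torch.ones(M, K // block_size)            # stand-in block scales

# (M, K) -> (M, K//block_size, block_size): each scale lines up with its block.
A_blocks = A_data.reshape(M, K // block_size, block_size)
A = A_blocks.to(torch.bfloat16) * A_scale.unsqueeze(-1).to(torch.bfloat16)

# Back to the original (M, K) shape.
A = A.reshape(M, K)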
@@ -471,27 +477,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_data: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_data: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_data.ndim == 2, "A must be 2D"
+    assert B_data.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_data.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_data.device,
+        requires_grad=A_data.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_data.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_data.device,
+        requires_grad=B_data.requires_grad,
     )
 
     # Dequantize input per each scaling group
@@ -507,7 +513,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_data[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape
 
         # Get scales for this group.
@@ -532,7 +538,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_data[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape
 
         # Scales shape is (group_size//block_size, N)
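
For reference on the surrounding loop (context around these hunks, not part of this diff's changes): offs holds the end offset of each token group along the shared dimension, so consecutive entries delimit the slice each expert's GEMM sees. A hedged sketch of that iteration pattern, with hypothetical values:

import torch

offs = torch.tensor([3, 7, 12])  # hypothetical per-group end offsets
group_start_idx = 0
for group_end_idx in offs.tolist():
    # Slice this group's columns of A_data and rows of B_data, dequantize them,
    # then matmul only that (M, group_size) @ (group_size, N) block.
    print(f"group covers indices [{group_start_idx}, {group_end_idx})")
    group_start_idx = group_end_idx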
