
Commit c66c5c0

[mxfp8 moe training] add per group blocked scale kernels
stack-info: PR: #2886, branch: danielvegamyhre/stack/62
1 parent 16f5bef commit c66c5c0

File tree

4 files changed: +287 −1 lines changed

test/prototype/moe_training/test_kernels.py
torchao/prototype/moe_training/kernels/__init__.py
torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py (added)
torchao/prototype/moe_training/kernels/mxfp8.py → mxfp8_gemms.py (renamed)

test/prototype/moe_training/test_kernels.py

Lines changed: 32 additions & 0 deletions
@@ -7,6 +7,8 @@
 import pytest
 import torch
 
+from torchao.prototype.mx_formats.utils import to_blocked_per_group_2d
+
 # We need to skip before doing any imports which would use triton, since
 # triton won't be available on CPU builds
 if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9):
@@ -20,12 +22,17 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
 )
+from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+    triton_mx_block_rearrange_per_group,
+)
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    generate_jagged_offs,
     torch_to_3d_rowwise_float8_transpose_rhs,
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.utils import skip_if_rocm
 
 
@@ -118,3 +125,28 @@ def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
     assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
     assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
     assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
+
+
+@skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
+@pytest.mark.parametrize("m,k,n_groups", [(256, 256, 4)])
+def test_mxfp8_per_group_blocked_scales_2d(
+    m: int, k: int, n_groups: int, round_scales_to_power_of_2: bool
+):
+    device = "cuda"
+    block_size = 32
+    input_data = torch.randn(m, k, device=device)
+    e8m0_scales, _ = to_mx(
+        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+    offs = generate_jagged_offs(n_groups, m, multiple_of=block_size, device=device)
+
+    # torch reference
+    ref_out = to_blocked_per_group_2d(e8m0_scales, offs, m, k, block_size=block_size)
+
+    # triton kernel
+    triton_out = triton_mx_block_rearrange_per_group(e8m0_scales, offs)
+
+    assert torch.allclose(ref_out, triton_out, atol=0, rtol=0), (
+        "blocked scales not equal"
+    )
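
For context on the jagged-offsets convention this test relies on, here is a minimal sketch of the contract assumed for generate_jagged_offs: it returns cumulative end indices, one per group, each a multiple of multiple_of, with the last equal to the total row count. The exact boundary values are up to the utility; only the invariants matter here.

import torch

# Hypothetical output for n_groups=4, m=256, multiple_of=32:
offs = torch.tensor([64, 96, 192, 256], device="cuda", dtype=torch.int32)

# Group i spans rows [offs[i-1], offs[i]) of the (m, k) input, with offs[-1] == m,
# so the group sizes in this example are 64, 32, 96, and 64 rows.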

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,6 +7,6 @@
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
 )
-from torchao.prototype.moe_training.kernels.mxfp8 import (
+from torchao.prototype.moe_training.kernels.mxfp8_gemms import (
     fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
 )
torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+from torch.library import triton_op, wrap_triton
+
+from torchao.utils import ceil_div
+
+
+def to_blocked_per_group_2d(
+    x_scales: Tensor, group_offs: Tensor, Mg: int, K: int, block_size: int = 32
+) -> tuple[Tensor, Tensor]:
+    """
+    Convert scales to blocked format per group for a 2D tensor (input activations / token groups).
+
+    Args:
+        x_scales: Tensor containing the e8m0 scales for each group, concatenated along dim 0.
+        group_offs: Tensor of shape (num_groups,) containing the end index of each group along the Mg dimension.
+        Mg: total size of all groups summed together
+        K: K dim size
+
+    Returns:
+        blocked_scales: Tensor with each group's scales in blocked format, concatenated along dim 0.
+        start_row_after_padding: Tensor of shape (num_groups + 1,) containing the start row after padding for each group, beginning with 0.
+    """
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
+
+    assert x_scales.ndim == 2, "x_scales must be 2D"
+    assert block_size == 32, "Only block_size=32 is supported for now"
+    blocked_scales_list = []
+    start_row_after_padding_list = [0]
+    group_start_idx = 0
+    for i, group_end_idx in enumerate(group_offs.tolist()):
+        group_size = group_end_idx - group_start_idx
+        prev_start_row_after_padding = start_row_after_padding_list[i]
+        if group_size == 0:
+            start_row_after_padding_list.append(prev_start_row_after_padding)
+            continue
+
+        # Convert group scales to blocked format
+        group_scales = x_scales[group_start_idx:group_end_idx]
+        group_scales_blocked = _to_blocked(group_scales)
+        blocked_scales_list.append(group_scales_blocked)
+
+        # Calculate the start row after padding
+        scaling_groups_per_row = K // block_size
+        rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
+        new_start_row = prev_start_row_after_padding + rows_for_group
+        start_row_after_padding_list.append(new_start_row)
+
+        # Update next group start index
+        group_start_idx = group_end_idx
+
+    blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous()
+    blocked_scales = blocked_scales.reshape(-1, K // block_size)
+    start_row_after_padding = torch.tensor(
+        start_row_after_padding_list, device=x_scales.device, dtype=torch.int64
+    )
+    return blocked_scales, start_row_after_padding
+
+
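To make the padding bookkeeping above concrete, here is a small pure-Python sketch of the start_row_after_padding computation, assuming (as the rows_for_group arithmetic implies) that fbgemm's _to_blocked pads each group's rows up to a multiple of 128:

import math

# Two groups ending at rows 100 and 256 of a (256, 8) scale tensor (K=256, block_size=32).
group_offs = [100, 256]

start_rows = [0]
group_start = 0
for group_end in group_offs:
    rows = group_end - group_start
    padded = math.ceil(rows / 128) * 128  # 100 -> 128 rows, 156 -> 256 rows
    start_rows.append(start_rows[-1] + padded)
    group_start = group_end

print(start_rows)  # [0, 128, 384]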
+def to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor:
+    """
+    Convert scales to blocked format for each group for a 3D tensor (expert weights).
+
+    Args:
+        weight_scales: Tensor of shape (E, N, K // block_size) containing per-expert scales.
+
+    Returns:
+        Tensor of shape (E, -1) with each expert's scales in blocked format.
+    """
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
+
+    blocked_scales_list = []
+    num_groups = weight_scales.shape[0]
+    for i in range(num_groups):
+        group_scales = weight_scales[i]
+        group_scales_blocked = _to_blocked(group_scales)
+        blocked_scales_list.append(group_scales_blocked)
+    weight_scales_blocked = torch.stack(blocked_scales_list, dim=0).contiguous()
+    weight_scales_blocked = weight_scales_blocked.reshape(num_groups, -1)
+    return weight_scales_blocked
+
+
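A quick shape illustration for the 3D path, with uint8 standing in for the 1-byte e8m0 scales (a sketch, assuming N and K // block_size are already multiples of the 128x4 tile so _to_blocked adds no padding):

import torch

E, N, K, block_size = 4, 256, 256, 32
weight_scales = torch.empty(E, N, K // block_size, dtype=torch.uint8, device="cuda")

# Each expert's (256, 8) scale matrix is swizzled independently and flattened,
# so the blocked result has one 2048-element row per expert: shape (4, 2048).
blocked = to_blocked_per_group_3d(weight_scales)
assert blocked.shape == (E, N * K // block_size)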
+def compute_per_group_blocked_scale_offsets(offsets: torch.Tensor):
+    """
+    Computes the per-group sizes and the starting row of each group after each
+    group's rows are padded up to the nearest multiple of 128.
+
+    Args:
+        offsets: A 1D tensor of integers in ascending order, representing the end index of each group along the Mg dimension.
+
+    Returns:
+        - group_sizes: A 1D tensor of integers representing the size of each group.
+        - starting_row_after_padding: A 1D tensor of shape (num_groups + 1,) holding the start row of each group after padding to blocked format, beginning with 0.
+    """
+    # Calculate group sizes
+    zero = torch.tensor([0], dtype=offsets.dtype, device=offsets.device)
+    group_sizes = torch.diff(offsets, prepend=zero).to(torch.int64)
+
+    # Round each group size up to the nearest multiple of 128
+    rounded_group_sizes = ceil_div(group_sizes, 128) * 128
+
+    # Cumsum with a leading zero, so entry i is the start row of group i in the
+    # padded blocked layout (the kernel reads both entry i and entry i + 1)
+    starting_row_after_padding = torch.cumsum(rounded_group_sizes, dim=0)
+    starting_row_after_padding = torch.cat(
+        [zero.to(torch.int64), starting_row_after_padding]
+    )
+    return group_sizes, starting_row_after_padding
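A worked example of the offset math, runnable on CPU since only integer bookkeeping is involved (import path per the module above):

import torch
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
    compute_per_group_blocked_scale_offsets,
)

# Three groups ending at rows 100, 256, and 300:
offsets = torch.tensor([100, 256, 300], dtype=torch.int32)
group_sizes, starts = compute_per_group_blocked_scale_offsets(offsets)

print(group_sizes.tolist())  # [100, 156, 44]
print(starts.tolist())       # [0, 128, 384, 512]: sizes rounded up to 128, then cumsum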
+@triton_op("torchao::triton_mx_block_rearrange_per_group", mutates_args=())
+def triton_mx_block_rearrange_per_group(
+    scales_tensor: torch.Tensor,
+    offsets: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Rearranges an E8M0 scales tensor into block-scaled swizzle format, group by group.
+
+    This format is suitable for Tmem as described in NVIDIA documentation:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor.
+        offsets: Tensor of shape (num_groups,) containing the end index of each group along the M dimension.
+
+    Returns:
+        Rearranged tensor in block-scaled swizzle format
+    """
+    assert scales_tensor.element_size() == 1, (
+        "Expected element size to be 1 byte (8 bits)"
+    )
+    _, output_scales_group_offsets = compute_per_group_blocked_scale_offsets(offsets)
+    rows, cols = scales_tensor.shape
+
+    # Calculate blocks needed; the final offset is the total number of rows in
+    # the padded output tensor
+    num_groups = offsets.numel()
+    padded_rows = output_scales_group_offsets[-1]
+    num_col_blocks = ceil_div(cols, 4)
+    padded_cols = num_col_blocks * 4
+    out = scales_tensor.new_empty((padded_rows, padded_cols))
+
+    # We probably want to handle multiple blocks per tile, but for now keep it simple
+    BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+    # Output block stride for the rearranged format
+    output_stride_per_row_of_blocks = (
+        BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+    )
+
+    # We parallelize per group and per col block.
+    # Rows per group is variable, so we just loop through row blocks per group, per col block.
+    grid = lambda META: (
+        num_groups,
+        num_col_blocks,
+    )
+
+    wrap_triton(triton_scale_swizzle_per_group)[grid](
+        # Input scales
+        scales_tensor.view(torch.uint8),
+        scales_tensor.stride(0),
+        scales_tensor.stride(1),
+        rows,
+        cols,
+        num_groups,
+        # Original offsets (to read from)
+        offsets,
+        # Output scales tensor and group offsets after padding (to write to)
+        out.view(torch.uint8),
+        output_scales_group_offsets,
+        output_stride_per_row_of_blocks,
+        BLOCK_ROWS=BLOCK_ROWS,
+        BLOCK_COLS=BLOCK_COLS,
+    )
+    return out
+
+
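A hedged end-to-end sketch of how the op composes with the quantization helpers (shapes mirror the unit test; generate_jagged_offs is the test utility shown earlier):

import torch
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
    triton_mx_block_rearrange_per_group,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.prototype.mx_formats.mx_tensor import to_mx

Mg, K, n_groups, block_size = 256, 256, 4, 32
x = torch.randn(Mg, K, device="cuda")

# e8m0_scales: (Mg, K // block_size) = (256, 8), one scale per 1x32 block
e8m0_scales, _ = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
offs = generate_jagged_offs(n_groups, Mg, multiple_of=block_size, device="cuda")

# Each group's rows are swizzled and padded up to a multiple of 128
blocked = triton_mx_block_rearrange_per_group(e8m0_scales, offs)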
+@triton.jit
+def triton_scale_swizzle_per_group(
+    scales_ptr,  # (M, K // block_size)
+    scales_stride_dim0,
+    scales_stride_dim1,
+    scale_rows,
+    scale_cols,
+    num_groups,
+    orig_offsets,  # (num_groups,)
+    output_scales_ptr,  # (rows + num_groups * 128, tl.cdiv(K, 4) * 4)
+    output_scales_group_offsets,  # (num_groups + 1,)
+    output_stride_per_row_of_blocks,
+    BLOCK_ROWS: tl.constexpr,
+    BLOCK_COLS: tl.constexpr,
+):
+    group_pid = tl.program_id(0)
+    block_col_pid = tl.program_id(1)
+
+    row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
+    col_offs = tl.arange(0, BLOCK_COLS)[None, :]
+
+    # Row range for this group
+    input_start_row = tl.load(orig_offsets + group_pid - 1, mask=group_pid > 0, other=0)
+    input_end_row = tl.load(
+        orig_offsets + group_pid, mask=group_pid < num_groups, other=0
+    )
+
+    # Base offset in the output scales tensor we will write to
+    output_row_start_offset = tl.load(
+        output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
+    )
+    output_row_end_offset = tl.load(
+        output_scales_group_offsets + group_pid + 1,
+        mask=group_pid < num_groups,
+        other=0,
+    )
+
+    # For this group and col block, we iterate through blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales.
+    # We need to track how many row blocks we iterated through.
+    block_row_id = 0
+    for row_off in tl.range(
+        input_start_row, tl.cdiv(input_end_row, BLOCK_ROWS) * BLOCK_ROWS, BLOCK_ROWS
+    ):
+        # Read block of input scales
+        block_row_offs = row_off + row_offs
+        block_col_offs = block_col_pid * BLOCK_COLS + col_offs
+        block_offs = (
+            block_row_offs * scales_stride_dim0 + block_col_offs * scales_stride_dim1
+        )
+        mask = (block_row_offs < input_end_row) & (block_col_offs < scale_cols)
+        input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0)
+
+        # Calculate destination indices for each row and col within the block
+        r_div_32 = row_offs // 32
+        r_mod_32 = row_offs % 32
+
+        # Rearrange to (32, 4, 4) then to final (32, 16) coordinates
+        dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs
+
+        # Flatten
+        dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
+        scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))
+
+        # Calculate block offset using provided output block stride
+        output_block_offsets = (
+            output_row_start_offset
+            + (block_row_id * output_stride_per_row_of_blocks)
+            + (block_col_pid * BLOCK_COLS)
+        )
+
+        tl.store(
+            output_scales_ptr + output_block_offsets + dest_indices_flat,
+            scales_flat,
+            mask=output_block_offsets < output_row_end_offset,
+        )
+
+        # Update row block id to next block
+        block_row_id += 1
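
The dest_indices line is the heart of the swizzle: within one (128, 4) tile, element (r, c) lands at flat position (r % 32) * 16 + (r // 32) * 4 + c of the 512-element swizzled tile, so the four 32-row sub-tiles are interleaved along the fast axis. A tiny standalone check of that index map:

def swizzled_index(r: int, c: int) -> int:
    # (128, 4) tile -> (32, 4, 4) -> flattened (32, 16) swizzle
    return (r % 32) * 16 + (r // 32) * 4 + c

# Rows 0, 32, 64, 96 (same column) map to adjacent 4-element chunks...
assert [swizzled_index(r, 0) for r in (0, 32, 64, 96)] == [0, 4, 8, 12]
# ...and the next logical row starts a new 16-wide chunk.
assert swizzled_index(1, 0) == 16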

torchao/prototype/moe_training/kernels/mxfp8.py renamed to torchao/prototype/moe_training/kernels/mxfp8_gemms.py

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,8 @@
         "If errors persist, please file a bug report."
     )
 
+DEBUG = False
+
 
 @torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
 def fbgemm_mxfp8_grouped_mm_2d_3d(
@@ -108,6 +110,9 @@ def _log_inputs(
     group_sizes: torch.Tensor,
     starting_row_after_padding: torch.Tensor,
 ):
+    if not DEBUG:
+        return
+
     logger.info(f"offs: {offs}, dtype: {offs.dtype}")
     logger.info(
         f"A_fp8.shape: {A_fp8.shape}, stride: {A_fp8.stride()}, dtype: {A_fp8.dtype}"
