Commit 75ae9d6
[mxfp8 moe training] add per group blocked scale kernels
stack-info: PR: #2886, branch: danielvegamyhre/stack/62
1 parent 16f5bef commit 75ae9d6

4 files changed: +298 −1 lines changed

test/prototype/moe_training/test_kernels.py

Lines changed: 39 additions & 0 deletions

@@ -7,6 +7,8 @@
 import pytest
 import torch
 
+from torchao.prototype.mx_formats.utils import to_blocked_per_group_2d
+
 # We need to skip before doing any imports which would use triton, since
 # triton won't be available on CPU builds
 if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9):
@@ -20,12 +22,17 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
 )
+from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+    triton_mx_block_rearrange_per_group,
+)
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    generate_jagged_offs,
     torch_to_3d_rowwise_float8_transpose_rhs,
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.utils import skip_if_rocm
 
 
@@ -118,3 +125,35 @@ def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
     assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
     assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
     assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
+
+
+@skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.parametrize("m,k,n_groups", [(256, 256, 4), (16640, 5120, 16)])
+def test_mxfp8_per_group_blocked_scales_2d(
+    m: int,
+    k: int,
+    n_groups: int,
+):
+    device = "cuda"
+    block_size = 32
+    input_data = torch.randn(m, k, device=device)
+    e8m0_scales, _ = to_mx(
+        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+    offs = generate_jagged_offs(n_groups, m, multiple_of=block_size, device=device)
+
+    # torch reference
+    ref_out_scales, ref_group_offsets = to_blocked_per_group_2d(
+        e8m0_scales, offs, m, k, block_size=block_size
+    )
+
+    # triton kernel
+    triton_out_scales, triton_group_offsets = triton_mx_block_rearrange_per_group(
+        e8m0_scales, offs
+    )
+    assert torch.allclose(ref_group_offsets, triton_group_offsets, atol=0, rtol=0), (
+        "group offsets not equal"
+    )
+    assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
+        "blocked scales not equal"
+    )
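
For intuition about the shapes the new test exercises: to_mx with block_size=32 emits one e8m0 scale per 32 contiguous values along K, so an (m, k) input yields (m, k // 32) scales. A minimal sketch along those lines (illustrative values, not from the diff):

    import torch
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    x = torch.randn(256, 256, device="cuda")
    e8m0_scales, _ = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=32)
    print(e8m0_scales.shape)  # expected: torch.Size([256, 8]), one scale per 32-wide block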

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,6 +7,6 @@
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
 )
-from torchao.prototype.moe_training.kernels.mxfp8 import (
+from torchao.prototype.moe_training.kernels.mxfp8_gemms import (
     fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
 )

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py

Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+
+from torchao.utils import ceil_div
+
+
+def to_blocked_per_group_2d(
+    x_scales: Tensor, group_offs: Tensor, Mg: int, K: int, block_size: int = 32
+) -> tuple[Tensor, Tensor]:
+    """
+    Convert scales to blocked format for a 2D tensor (input activations / token groups).
+
+    Args:
+        x_scales: Tensor with the per-group e8m0 scales concatenated into one tensor along the Mg dimension.
+        group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the Mg dimension.
+        Mg: total size of all groups summed together
+        K: K dim size
+
+    Returns:
+        blocked_scales: Tensor with the per-group scales in blocked format, concatenated into one tensor.
+        start_row_after_padding: Tensor of shape (num_groups + 1,) which contains the start row after padding for each group.
+    """
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
+
+    assert x_scales.ndim == 2, "x_scales must be 2D"
+    assert block_size == 32, "Only block_size=32 is supported for now"
+    blocked_scales_list = []
+    start_row_after_padding_list = [0]
+    group_start_idx = 0
+    for i, group_end_idx in enumerate(group_offs.tolist()):
+        group_size = group_end_idx - group_start_idx
+        prev_start_row_after_padding = start_row_after_padding_list[i]
+        if group_size == 0:
+            start_row_after_padding_list.append(prev_start_row_after_padding)
+            continue
+
+        # Convert group scales to blocked format
+        group_scales = x_scales[group_start_idx:group_end_idx]
+        group_scales_blocked = _to_blocked(group_scales)
+        blocked_scales_list.append(group_scales_blocked)
+
+        # Calculate the start row after padding
+        scaling_groups_per_row = K // block_size
+        rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
+        new_start_row = prev_start_row_after_padding + rows_for_group
+        start_row_after_padding_list.append(new_start_row)
+
+        # Update next group start index
+        group_start_idx = group_end_idx
+
+    blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous()
+    blocked_scales = blocked_scales.reshape(-1, K // 32)
+    start_row_after_padding = torch.tensor(
+        start_row_after_padding_list, device=x_scales.device, dtype=torch.int64
+    )
+    return blocked_scales, start_row_after_padding
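
To make the row padding concrete: _to_blocked pads each group to a multiple of 128 rows (and the scale columns to a multiple of 4), so when K // block_size is already a multiple of 4, rows_for_group is just the group's row count rounded up to 128. An illustrative calculation (hypothetical group sizes, not from the diff):

    # Expected start rows for groups of 100, 200, and 50 scale rows
    group_rows = [100, 200, 50]
    start_rows = [0]
    for rows in group_rows:
        padded = ((rows + 127) // 128) * 128  # 100 -> 128, 200 -> 256, 50 -> 128
        start_rows.append(start_rows[-1] + padded)
    print(start_rows)  # [0, 128, 384, 512]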
+
+
+def to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor:
+    """
+    Convert scales to blocked format for each group for a 3D tensor (expert weights).
+
+    Args:
+        weight_scales: Tensor of shape (E, N, K // block_size) containing the
+            e8m0 scales for each expert group.
+    """
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
+
+    blocked_scales_list = []
+    num_groups = weight_scales.shape[0]
+    for i in range(num_groups):
+        group_scales = weight_scales[i]
+        group_scales_blocked = _to_blocked(group_scales)
+        blocked_scales_list.append(group_scales_blocked)
+    weight_scales_blocked = torch.stack(blocked_scales_list, dim=0).contiguous()
+    weight_scales_blocked = weight_scales_blocked.reshape(num_groups, -1)
+    return weight_scales_blocked
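
Shape-wise, each expert's (N, K // block_size) scale matrix is padded independently, so the stacked result holds padded_rows * padded_cols entries per expert. A quick arithmetic sketch (hypothetical sizes, not from the diff):

    # Assumes _to_blocked pads rows to a multiple of 128 and cols to a multiple of 4
    E, N, scale_cols = 2, 100, 8
    padded = ((N + 127) // 128) * 128 * ((scale_cols + 3) // 4) * 4
    print((E, padded))  # expected blocked shape: (2, 1024)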
+
+
+def compute_per_group_blocked_scale_offsets(offsets: torch.Tensor):
+    """
+    Rounds each group size up to the nearest multiple of 128, given the group end offsets.
+
+    Args:
+        offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the Mg dimension.
+
+    Returns:
+        - group_sizes: A 1D PyTorch tensor of integers representing the size of each group.
+        - starting_row_after_padding: 1D integer tensor of shape (num_groups + 1,) representing the starting row of each group after padding to blocked format.
+    """
+    # Calculate group sizes
+    zero = torch.tensor([0], dtype=offsets.dtype, device=offsets.device)
+    group_sizes = torch.diff(offsets, prepend=zero).to(torch.int64)
+
+    # Round each group size up to the nearest multiple of 128
+    rounded_group_sizes = ceil_div(group_sizes, 128) * 128
+
+    # Calculate the starting row after padding for each group
+    starting_row_after_padding = torch.cumsum(rounded_group_sizes, dim=0)
+
+    # Must start with 0
+    starting_row_after_padding = torch.cat([zero, starting_row_after_padding])
+    return group_sizes, starting_row_after_padding
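
A concrete run of this helper (illustrative offsets, not from the diff):

    import torch
    from torchao.utils import ceil_div

    offsets = torch.tensor([96, 160, 416])           # group end rows along Mg
    zero = torch.tensor([0], dtype=offsets.dtype)
    group_sizes = torch.diff(offsets, prepend=zero)  # tensor([ 96,  64, 256])
    rounded = ceil_div(group_sizes, 128) * 128       # tensor([128, 128, 256])
    print(torch.cat([zero, torch.cumsum(rounded, dim=0)]))  # tensor([  0, 128, 256, 512])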
+
+
+# @triton_op("torchao::triton_mx_block_rearrange_per_group", mutates_args=())
+def triton_mx_block_rearrange_per_group(
+    scales_tensor: torch.Tensor,
+    offsets: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Rearranges an E8M0 scales tensor to the block-scaled swizzle format.
+
+    This format is suitable for Tmem as described in NVIDIA documentation:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor.
+        offsets: Tensor of shape (num_groups,) which contains the end index of each group along the M dimension.
+
+    Returns:
+        Rearranged tensor in block-scaled swizzle format, and the start row of each group after padding.
+    """
+    assert scales_tensor.element_size() == 1, (
+        "Expected element size to be 1 byte (8 bits)"
+    )
+    _, output_scales_group_offsets = compute_per_group_blocked_scale_offsets(offsets)
+    rows, cols = scales_tensor.shape
+
+    # Calculate blocks needed
+    num_groups = output_scales_group_offsets.numel()
+
+    # Final offset is the total number of rows in the padded output tensor
+    padded_rows = output_scales_group_offsets[-1]
+    num_col_blocks = ceil_div(cols, 4)
+    padded_cols = num_col_blocks * 4
+    out = scales_tensor.new_empty((padded_rows, padded_cols))
+
+    # We probably want to handle multiple blocks per tile, but for now keep it simple
+    BLOCK_ROWS, BLOCK_COLS = 128, 4
+
+    # Output block stride for the rearranged format
+    output_stride_per_block = BLOCK_ROWS * BLOCK_COLS
+    output_stride_per_row_of_blocks = (
+        BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS)
+    )
+
+    # We parallelize per group and per col block.
+    # Rows per group is variable, so we just loop through row blocks per group, per col block.
+    grid = lambda META: (
+        num_groups,
+        num_col_blocks,
+    )
+
+    triton_scale_swizzle_per_group[grid](
+        # Input scales
+        scales_tensor.view(torch.uint8),
+        scales_tensor.stride(0),
+        scales_tensor.stride(1),
+        rows,
+        cols,
+        num_groups,
+        # Original offsets (to read from)
+        offsets,
+        # Output scales tensor and group offsets after padding (to write to)
+        out.view(torch.uint8),
+        output_scales_group_offsets,
+        output_stride_per_block,
+        output_stride_per_row_of_blocks,
+        BLOCK_ROWS=BLOCK_ROWS,
+        BLOCK_COLS=BLOCK_COLS,
+    )
+    return out, output_scales_group_offsets
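
A minimal usage sketch for the wrapper above (illustrative values, not from the diff; requires CUDA and Triton, with random bytes standing in for real e8m0 scales):

    import torch

    scales = torch.randint(0, 255, (256, 8), dtype=torch.uint8, device="cuda")
    offs = torch.tensor([64, 96, 192, 256], device="cuda", dtype=torch.int32)
    blocked, padded_offs = triton_mx_block_rearrange_per_group(scales, offs)
    print(blocked.shape)  # expected: (512, 8), each of the 4 groups padded to 128 rows
    print(padded_offs)    # expected: tensor([0, 128, 256, 384, 512])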
+
+
+@triton.jit
+def triton_scale_swizzle_per_group(
+    scales_ptr,  # (M, K // block_size)
+    scales_stride_dim0,
+    scales_stride_dim1,
+    scale_rows,
+    scale_cols,
+    num_groups,
+    orig_offsets,  # (num_groups,)
+    output_scales_ptr,  # (rows + num_groups * 128, tl.cdiv(K, 4) * 4)
+    output_scales_group_offsets,  # (num_groups,)
+    output_stride_per_block,
+    output_stride_per_row_of_blocks,
+    BLOCK_ROWS: tl.constexpr,
+    BLOCK_COLS: tl.constexpr,
+):
+    group_pid = tl.program_id(0)
+    block_col_pid = tl.program_id(1)
+
+    # Input scales row range for this group
+    input_group_start_row = tl.load(
+        orig_offsets + group_pid - 1, mask=group_pid > 0, other=0
+    )
+    input_group_end_row = tl.load(
+        orig_offsets + group_pid, mask=group_pid < num_groups, other=0
+    )
+
+    # Output scales start row we will begin writing to
+    output_group_start_row = tl.load(
+        output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
+    )
+
+    # Calculate destination indices for each row and col in the block-swizzled layout.
+    # We can reuse this swizzle transformation on each block of data we read.
+    row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
+    col_offs = tl.arange(0, BLOCK_COLS)[None, :]
+    r_div_32 = row_offs // 32
+    r_mod_32 = row_offs % 32
+
+    # Rearrange to (32, 4, 4) then to final (32, 16) coordinates
+    dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs
+
+    # Flatten
+    dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS))
+
+    # For this group and col block, we iterate through row blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales.
+    # We track how many row blocks we have iterated through.
+    block_row_id = 0
+    current_start_row = input_group_start_row
+    while current_start_row < input_group_end_row:
+        # Read block of input scales
+        block_row_offs = current_start_row + row_offs
+        block_col_offs = block_col_pid * BLOCK_COLS + col_offs
+        block_offs = (
+            block_row_offs * scales_stride_dim0 + block_col_offs * scales_stride_dim1
+        )
+        mask = (block_row_offs < input_group_end_row) & (block_col_offs < scale_cols)
+        input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0)
+        scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS))
+
+        # Calculate block offset using provided output block stride
+        output_block_offsets = (
+            output_group_start_row * scale_cols
+            + (block_row_id * output_stride_per_row_of_blocks)
+            + (block_col_pid * output_stride_per_block)
+        )
+
+        # Apply swizzling for write to gmem
+        tl.store(
+            output_scales_ptr + output_block_offsets + dest_indices_flat,
+            scales_flat,
+        )
+
+        # Update row block id to next block
+        block_row_id += 1
+        current_start_row += BLOCK_ROWS
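
The dest_indices formula above maps each (128, 4) tile onto the (32, 16) swizzled layout; a quick editorial check (plain PyTorch, not from the diff) that it permutes all 512 positions of a tile exactly once:

    import torch

    rows = torch.arange(128)[:, None]  # BLOCK_ROWS
    cols = torch.arange(4)[None, :]    # BLOCK_COLS
    dest = (rows % 32) * 16 + (rows // 32) * 4 + cols
    assert sorted(dest.flatten().tolist()) == list(range(128 * 4))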

torchao/prototype/moe_training/kernels/mxfp8.py renamed to torchao/prototype/moe_training/kernels/mxfp8_gemms.py

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,8 @@
     "If errors persist, please file a bug report."
 )
 
+DEBUG = False
+
 
 @torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
 def fbgemm_mxfp8_grouped_mm_2d_3d(
@@ -108,6 +110,9 @@ def _log_inputs(
     group_sizes: torch.Tensor,
     starting_row_after_padding: torch.Tensor,
 ):
+    if not DEBUG:
+        return
+
     logger.info(f"offs: {offs}, dtype: {offs.dtype}")
     logger.info(
         f"A_fp8.shape: {A_fp8.shape}, stride: {A_fp8.stride()}, dtype: {A_fp8.dtype}"
