
Commit 4ecc89e

[mxfp8 moe training] add per group blocked scale kernels (#2886)
1 parent 2f78cfe commit 4ecc89e

File tree

7 files changed: +512 additions, −98 deletions

benchmarks/prototype/moe_training/bench_2d-3d_grouped_gemm.py renamed to benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py

Lines changed: 10 additions & 10 deletions
@@ -12,16 +12,16 @@
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
-from utils import benchmark_cuda_function_in_microseconds
 
+from benchmarks.utils import benchmark_cuda_function_in_microseconds
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+    torch_to_blocked_per_group_2d,
+    torch_to_blocked_per_group_3d,
+)
 from torchao.prototype.moe_training.utils import generate_jagged_offs
 from torchao.prototype.mx_formats.mx_tensor import to_mx
-from torchao.prototype.mx_formats.utils import (
-    to_blocked_per_group_2d,
-    to_blocked_per_group_3d,
-)
 
 device = torch.device("cuda")

@@ -50,9 +50,9 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # Llama4 shapes
     M = [16640]
-    K = [5120]
-    N = [8192]
-    E = [16]
+    K = [2048, 5120, 8192]
+    N = [2048, 5120, 8192]
+    E = [1, 2, 4, 8]
     configs = []
     for e, m, n, k in itertools.product(
         E,

@@ -196,10 +196,10 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
 
     # Convert scales for each group to blocked format.
     Mg, K = A_fp8.shape
-    A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
+    A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d(
         A_scales, offs, Mg, K
     )
-    B_scales_blocked = to_blocked_per_group_3d(B_scales)
+    B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)
 
     # From this, we compute `group_sizes` and `starting_row_after_padding`:
     # group_sizes = [32, 32, 64]
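
For orientation, here is a minimal sketch of how the renamed per-group blocked-scale helpers fit around an mxfp8 grouped GEMM. The call signatures come from the diff above; the shapes, dtypes, and the `to_mx` quantization step are illustrative assumptions, not code from this commit.

import torch

from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
    torch_to_blocked_per_group_2d,
    torch_to_blocked_per_group_3d,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.prototype.mx_formats.mx_tensor import to_mx

device = torch.device("cuda")
Mg, K, N, E, block_size = 256, 512, 512, 4, 32  # illustrative sizes, not from the commit

A = torch.randn(Mg, K, device=device, dtype=torch.bfloat16)      # token groups stacked along Mg
B_t = torch.randn(E, N, K, device=device, dtype=torch.bfloat16)  # per-expert weights (assumed layout)
offs = generate_jagged_offs(E, Mg, multiple_of=block_size)       # jagged group end offsets

# Quantize to mxfp8: one e8m0 scale per block of 32 elements along K.
A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
B_scales, B_fp8 = to_mx(B_t, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# Rearrange per-group scales into the blocked layout the grouped GEMM consumes.
A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d(
    A_scales, offs, Mg, K
)
B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)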
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
+
+import itertools
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from tabulate import tabulate
+from tqdm import tqdm
+
+from benchmarks.utils import benchmark_cuda_function_in_microseconds
+from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+    compute_per_group_blocked_scale_offsets,
+    torch_to_blocked_per_group_2d,
+    triton_mx_block_rearrange_per_group_2d,
+)
+from torchao.prototype.moe_training.utils import generate_jagged_offs
+
+device = torch.device("cuda")
+
+# Needed since changing args to function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    input_shape: tuple[int]
+    num_groups: int
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    torch_time_us: float
+    triton_time_us: float
+    torch_mem_bw_gbps: float
+    triton_mem_bw_gbps: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    # Llama4 shapes. Input activations are scaled along K dim.
+    block_size = 32
+    input_shapes = [
+        (16640, 5120 // block_size),
+    ]
+    num_groups = [16]
+    configs = []
+    for shape, groups in itertools.product(
+        input_shapes,
+        num_groups,
+    ):
+        configs.append(
+            ExperimentConfig(
+                input_shape=shape,
+                num_groups=groups,
+            )
+        )
+    return configs
+
+
+def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+    input_shape, num_groups = config.input_shape, config.num_groups
+    input_tensor = torch.randint(
+        low=0,
+        high=256,
+        size=input_shape,
+        dtype=torch.uint8,
+        device=device,
+    )
+
+    Mg, K = input_shape
+    input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)
+
+    # bench torch
+    compiled_run_torch = torch.compile(torch_to_blocked_per_group_2d)
+    torch_out_scales, torch_group_offs = compiled_run_torch(
+        input_tensor, input_group_offsets, Mg, K
+    )
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        compiled_run_torch,
+        input_tensor,
+        input_group_offsets,
+        Mg,
+        K,
+    )
+
+    # bench triton
+    _, output_group_offsets = compute_per_group_blocked_scale_offsets(
+        input_group_offsets
+    )
+    triton_out_scales = triton_mx_block_rearrange_per_group_2d(
+        input_tensor,
+        input_group_offsets,
+        output_group_offsets,
+    )
+    triton_time_us = benchmark_cuda_function_in_microseconds(
+        triton_mx_block_rearrange_per_group_2d,
+        input_tensor,
+        input_group_offsets,
+        output_group_offsets,
+    )
+
+    # mem bw calculations
+    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
+    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
+
+    read_bytes = input_tensor.numel() * bytes_per_input_el
+    write_bytes = triton_out_scales.numel() * bytes_per_output_el
+
+    torch_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_time_us / 1e6)
+    triton_mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
+
+    return ExperimentResult(
+        torch_time_us=torch_time_us,
+        triton_time_us=triton_time_us,
+        torch_mem_bw_gbps=torch_mem_bw_gbps,
+        triton_mem_bw_gbps=triton_mem_bw_gbps,
+    )
+
+
+def print_results(experiments: List[Experiment]):
+    headers = [
+        "input_shape",
+        "torch_time_us",
+        "triton_time_us",
+        "torch_mem_bw_gbps",
+        "triton_mem_bw_gbps",
+        "triton_speedup",
+    ]
+    rows = []
+    for experiment in experiments:
+        input_shape = (
+            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
+        )
+        rows.append(
+            [
+                input_shape,
+                experiment.result.torch_time_us,
+                experiment.result.triton_time_us,
+                round(experiment.result.torch_mem_bw_gbps, 3),
+                round(experiment.result.triton_mem_bw_gbps, 3),
+                f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
+            ]
+        )
+    print(tabulate(rows, headers=headers))
+
+
+def main():
+    torch.random.manual_seed(123)
+    configs = get_configs()
+    results = []
+    for config in tqdm(configs):
+        result = run_experiment(config)
+        results.append(Experiment(config=config, result=result))
+
+    # Use Tabulate to print results
+    print_results(results)
+
+
+if __name__ == "__main__":
+    main()
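
A note on the offsets handed to the triton kernel above: `compute_per_group_blocked_scale_offsets` is consumed only for its second return value, `output_group_offsets`. Based on its call sites here and on the 128-row tile padding used by blocked mx scale layouts, a plausible reading is that it rounds each group's row count up to a multiple of 128 and returns the cumulative padded start offsets. The sketch below spells out that assumption; it is not the kernel's actual implementation.

import torch

def assumed_blocked_scale_offsets(input_group_offsets: torch.Tensor, tile_rows: int = 128):
    # input_group_offsets: cumulative end row of each group, e.g. [32, 64, 128] (assumed meaning)
    zero = torch.zeros(1, dtype=input_group_offsets.dtype, device=input_group_offsets.device)
    starts = torch.cat([zero, input_group_offsets[:-1]])
    group_sizes = input_group_offsets - starts
    # Pad each group's rows up to the 128-row tile height of the blocked layout (assumed).
    padded_sizes = ((group_sizes + tile_rows - 1) // tile_rows) * tile_rows
    output_group_offsets = torch.cumsum(padded_sizes, dim=0)
    return group_sizes, output_group_offsets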

test/prototype/moe_training/test_kernels.py

Lines changed: 45 additions & 0 deletions
@@ -21,12 +21,19 @@
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
 )
+from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
+    compute_per_group_blocked_scale_offsets,
+    torch_to_blocked_per_group_2d,
+    triton_mx_block_rearrange_per_group_2d,
+)
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    generate_jagged_offs,
     torch_to_3d_rowwise_float8_transpose_rhs,
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.utils import skip_if_rocm
 
 

@@ -195,3 +202,41 @@ def test_fp8_rowwise_3d_transpose_rhs_reduction(round_scales_to_power_of_2: bool
     assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
     assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
     assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
+
+
+@skip_if_rocm("ROCm enablement in progress")
+@pytest.mark.parametrize(
+    "m,k,n_groups", [(256, 256, 4), (16640, 5120, 16), (16640, 8192, 16)]
+)
+def test_mxfp8_per_group_blocked_scales_2d(
+    m: int,
+    k: int,
+    n_groups: int,
+):
+    device = "cuda"
+    block_size = 32
+    input_data = torch.randn(m, k, device=device)
+    e8m0_scales, _ = to_mx(
+        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+    )
+    input_group_offsets = generate_jagged_offs(
+        n_groups, m, multiple_of=block_size, device=device
+    )
+
+    # torch reference
+    ref_out_scales, _ = torch_to_blocked_per_group_2d(
+        e8m0_scales, input_group_offsets, m, k, block_size=block_size
+    )
+
+    # triton kernel
+    _, output_group_offsets = compute_per_group_blocked_scale_offsets(
+        input_group_offsets
+    )
+    triton_out_scales = triton_mx_block_rearrange_per_group_2d(
+        e8m0_scales,
+        input_group_offsets,
+        output_group_offsets,
+    )
+    assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
+        "blocked scales not equal"
+    )
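
The new test compares the torch reference and the triton kernel bit-for-bit (atol=0, rtol=0). To run just this test with standard pytest filtering:

pytest test/prototype/moe_training/test_kernels.py -k test_mxfp8_per_group_blocked_scales_2d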

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,6 +7,6 @@
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
 )
-from torchao.prototype.moe_training.kernels.mxfp8 import (
+from torchao.prototype.moe_training.kernels.mxfp8_gemms import (
     fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
 )
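
Because the package re-exports the symbol, call sites that import through the package namespace are unaffected by the `mxfp8` to `mxfp8_gemms` module rename. For example, this import resolves the same before and after the commit:

from torchao.prototype.moe_training.kernels import fbgemm_mxfp8_grouped_mm_2d_3d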
