Skip to content

Commit dfc7d87

Browse files
[moe training] add bench script for fp8 rowwise kernels and update autotune configs
stack-info: PR: #2697, branch: danielvegamyhre/stack/31
1 parent 4d57aa3 commit dfc7d87

File tree

4 files changed

+357
-11
lines changed

4 files changed

+357
-11
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
from triton.testing import do_bench
16+
17+
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
18+
triton_fp8_col_major_jagged_colwise_scales,
19+
triton_fp8_row_major_jagged_rowwise_scales,
20+
)
21+
from torchao.prototype.moe_training.utils import (
22+
torch_to_float8_per_group_colwise,
23+
torch_to_float8_per_group_rowwise,
24+
)
25+
26+
device = torch.device("cuda")
27+
28+
# Needed since changing args to function causes recompiles
29+
torch._dynamo.config.cache_size_limit = 1000
30+
31+
32+
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of a single benchmark case."""

    # dtype used for the randomly generated high-precision input tensor
    high_precision_dtype: torch.dtype
    # dimensions unpacked into torch.randn; variable-length, hence `...`
    # (a bare `tuple[int]` would describe a 1-tuple only)
    input_shape: tuple[int, ...]
    # number of equal-sized groups the second input dimension is divided into
    n_groups: int
37+
38+
39+
@dataclass(frozen=True)
class ExperimentResult:
    """Immutable pair of timings collected for one benchmark case."""

    # runtime measured for the torch.compile reference path, in microseconds
    torch_time_us: float
    # runtime measured for the Triton kernel path, in microseconds
    triton_time_us: float
43+
44+
45+
@dataclass(frozen=True)
class Experiment:
    """Immutable record tying a benchmark configuration to its timings."""

    # the parameters the benchmark was run with
    config: ExperimentConfig
    # the timings measured for that configuration
    result: ExperimentResult
49+
50+
51+
def get_configs() -> List[ExperimentConfig]:
    """Build the benchmark sweep.

    Returns:
        One ExperimentConfig for every combination of input shape, group
        count and high-precision dtype, with shape varying slowest.
    """
    shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
    group_counts = [4, 8, 16]
    dtypes = [torch.bfloat16]
    return [
        ExperimentConfig(
            high_precision_dtype=dtype,
            input_shape=shape,
            n_groups=groups,
        )
        for shape, groups, dtype in itertools.product(shapes, group_counts, dtypes)
    ]
67+
68+
69+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Benchmark the torch.compile reference against the Triton jagged-scale kernels.

    A random tensor of ``config.input_shape`` is converted to float8
    (``torch.float8_e4m3fn``, scales rounded to powers of 2) twice per path:
    per-group rowwise on the row-major tensor and per-group colwise on its
    transpose.

    Args:
        config: input shape, dtype and group count for this run.

    Returns:
        ExperimentResult holding the timings of the compiled torch path and
        of the Triton path, as returned by
        ``benchmark_cuda_function_in_microseconds``.
    """
    # define test inputs
    input_tensor = torch.randn(
        *config.input_shape,
        dtype=config.high_precision_dtype,
        device=device,
    )
    input_row_major = input_tensor.clone().detach()
    input_col_major = input_tensor.clone().detach().t()

    # - configure input to be row-major with groups divided along the column dimension,
    #   representing the left operand of grad_weight = grad_output_t @ input
    #   that occurs in the backward pass of the differentiable scaled grouped mm.
    # - the transposed tensor in col-major format with groups along the row dimension,
    #   which represents the right operand.
    group_size = input_row_major.shape[1] // config.n_groups
    n_groups = config.n_groups
    # end offset of each group: group_size, 2 * group_size, ..., n_groups * group_size
    offs = torch.arange(
        group_size,
        group_size * n_groups + 1,
        group_size,
        device=device,
        dtype=torch.int32,
    )

    def warmup(func, *args, **kwargs):
        # run a few untimed iterations so one-time costs stay out of the measurement
        for _ in range(10):
            func(*args, **kwargs)

    def run_torch(
        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
    ):
        rowwise_out = torch_to_float8_per_group_rowwise(
            input_row_major,
            offs,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        colwise_out = torch_to_float8_per_group_colwise(
            input_col_major,
            offs,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        # Return the results: this function is compiled, and outputs that are
        # never used give the compiler an opening to prune the very work
        # being benchmarked.
        return rowwise_out, colwise_out

    def run_triton(
        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
    ):
        rowwise_out = triton_fp8_row_major_jagged_rowwise_scales(
            input_row_major,
            offs,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        colwise_out = triton_fp8_col_major_jagged_colwise_scales(
            input_col_major,
            offs,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        return rowwise_out, colwise_out

    # bench torch (warm up the compiled callable, mirroring the triton path)
    compiled_run_torch = torch.compile(run_torch)
    warmup(compiled_run_torch, input_row_major, input_col_major, offs)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch, input_row_major, input_col_major, offs
    )

    # bench triton
    warmup(run_triton, input_row_major, input_col_major, offs)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        run_triton, input_row_major, input_col_major, offs
    )

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
    )
146+
147+
148+
def print_results(experiments: List[Experiment]):
    """Print the benchmark results as a table, one row per experiment.

    Args:
        experiments: completed experiments; each contributes its input shape
            (first two dimensions), group count, dtype and both timings.
    """
    headers = [
        "input_shape",
        "n_groups",
        "high_precision_dtype",
        "torch_time_us",
        "triton_time_us",
    ]

    def to_row(experiment):
        # flatten one experiment into the column order declared in `headers`
        cfg, res = experiment.config, experiment.result
        shape = cfg.input_shape
        return [
            f"({shape[0]}, {shape[1]})",
            cfg.n_groups,
            cfg.high_precision_dtype,
            res.torch_time_us,
            res.triton_time_us,
        ]

    rows = [to_row(experiment) for experiment in experiments]
    print(tabulate(rows, headers=headers))
171+
172+
173+
def benchmark_cuda_function_in_microseconds(f, *args):
    """Return the median runtime of ``f(*args)`` reported by ``do_bench``, scaled by 1e3.

    Args:
        f: callable to benchmark.
        *args: positional arguments forwarded to ``f`` on every invocation.

    Returns:
        The ``do_bench`` median multiplied by 1e3 (do_bench reports
        milliseconds per the Triton docs, so this is microseconds).
    """

    def call():
        return f(*args)

    median_time = do_bench(call, return_mode="median")
    return median_time * 1e3
175+
176+
177+
def main():
    """Seed the RNG, run every configured experiment and print the results table."""
    torch.random.manual_seed(123)
    experiments = [
        Experiment(config=cfg, result=run_experiment(cfg))
        for cfg in tqdm(get_configs())
    ]

    # Use Tabulate to print results
    print_results(experiments)


if __name__ == "__main__":
    main()
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
from triton.testing import do_bench
16+
17+
from torchao.prototype.moe_training.kernels.float8_rowwise import (
18+
triton_fp8_rowwise_3d_transpose_rhs,
19+
)
20+
from torchao.prototype.moe_training.utils import (
21+
torch_to_3d_rowwise_float8_transpose_rhs,
22+
)
23+
24+
device = torch.device("cuda")
25+
26+
# Needed since changing args to function causes recompiles
27+
torch._dynamo.config.cache_size_limit = 1000
28+
29+
30+
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of a single benchmark case."""

    # dtype used for the randomly generated high-precision input tensor
    high_precision_dtype: torch.dtype
    # dimensions unpacked into torch.randn; variable-length, hence `...`
    # (a bare `tuple[int]` would describe a 1-tuple only)
    input_shape: tuple[int, ...]
34+
35+
36+
@dataclass(frozen=True)
class ExperimentResult:
    """Immutable pair of timings collected for one benchmark case."""

    # runtime measured for the torch.compile reference path, in microseconds
    torch_time_us: float
    # runtime measured for the Triton kernel path, in microseconds
    triton_time_us: float
40+
41+
42+
@dataclass(frozen=True)
class Experiment:
    """Immutable record tying a benchmark configuration to its timings."""

    # the parameters the benchmark was run with
    config: ExperimentConfig
    # the timings measured for that configuration
    result: ExperimentResult
46+
47+
48+
def get_configs() -> List[ExperimentConfig]:
    """Build the benchmark sweep.

    Returns:
        One ExperimentConfig for every combination of input shape and
        high-precision dtype, with shape varying slowest.
    """
    # Llama4 and DeepSeekV3 shapes
    shapes = [(8, 4096, 1024), (16, 5120 * 4, 5120)]
    dtypes = [torch.bfloat16]
    return [
        ExperimentConfig(
            high_precision_dtype=dtype,
            input_shape=shape,
        )
        for shape, dtype in itertools.product(shapes, dtypes)
    ]
63+
64+
65+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Benchmark the torch.compile reference against the Triton 3D rowwise kernel.

    A random tensor of ``config.input_shape`` is created, its last two
    dimensions are transposed, and the result is converted to float8
    (``torch.float8_e4m3fn``, scales rounded to powers of 2) by each path.

    Args:
        config: input shape and dtype for this run.

    Returns:
        ExperimentResult holding the timings of the compiled torch path and
        of the Triton path, as returned by
        ``benchmark_cuda_function_in_microseconds``.
    """
    # Expert weights will be passed in transposed and column major in practice
    input_tensor = torch.randn(
        *config.input_shape,
        dtype=config.high_precision_dtype,
        device=device,
    ).transpose(-2, -1)

    def warmup(func, *args, **kwargs):
        # run a few untimed iterations so one-time costs stay out of the measurement
        for _ in range(10):
            func(*args, **kwargs)

    # NOTE: no torch.cuda.synchronize() inside the timed callables. do_bench
    # performs its own synchronisation and event-based timing, so a host sync
    # here would only add launch/sync latency to every measured iteration.
    def run_torch(input_tensor: torch.Tensor):
        # return the result so the compiled graph keeps the conversion live
        return torch_to_3d_rowwise_float8_transpose_rhs(
            input_tensor,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )

    def run_triton(input_tensor: torch.Tensor):
        return triton_fp8_rowwise_3d_transpose_rhs(
            input_tensor,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )

    # bench torch: warm up the *compiled* callable, since that is what gets
    # timed (warming the eager function would leave compilation unaccounted for)
    compiled_run_torch = torch.compile(run_torch)
    warmup(compiled_run_torch, input_tensor)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch,
        input_tensor,
    )

    # bench triton
    warmup(run_triton, input_tensor)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        run_triton,
        input_tensor,
    )

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
    )
113+
114+
115+
def print_results(experiments: List[Experiment]):
    """Print the benchmark results as a table, one row per experiment.

    Args:
        experiments: completed experiments; each contributes its input shape
            and both timings.
    """
    headers = [
        "input_shape",
        "torch_time_us",
        "triton_time_us",
    ]
    rows = []
    for experiment in experiments:
        # Render all dimensions flat, e.g. "(8, 4096, 1024)". Placing two
        # expressions in one f-string field formats them as a nested tuple,
        # which printed "(8, (4096, 1024))".
        dims = ", ".join(str(dim) for dim in experiment.config.input_shape)
        input_shape = f"({dims})"
        rows.append(
            [
                input_shape,
                experiment.result.torch_time_us,
                experiment.result.triton_time_us,
            ]
        )
    print(tabulate(rows, headers=headers))
132+
133+
134+
def benchmark_cuda_function_in_microseconds(f, *args):
    """Return the median runtime of ``f(*args)`` reported by ``do_bench``, scaled by 1e3.

    Args:
        f: callable to benchmark.
        *args: positional arguments forwarded to ``f`` on every invocation.

    Returns:
        The ``do_bench`` median multiplied by 1e3 (do_bench reports
        milliseconds per the Triton docs, so this is microseconds).
    """

    def call():
        return f(*args)

    median_time = do_bench(call, return_mode="median")
    return median_time * 1e3
136+
137+
138+
def main():
    """Seed the RNG, run every configured experiment and print the results table."""
    torch.random.manual_seed(123)
    experiments = [
        Experiment(config=cfg, result=run_experiment(cfg))
        for cfg in tqdm(get_configs())
    ]

    # Use Tabulate to print results
    print_results(experiments)


if __name__ == "__main__":
    main()

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@
2626
torch.float64: tl.float64,
2727
}
2828

29-
block_sizes = [16]
30-
num_warps = [4]
31-
num_stages = [2]
29+
block_sizes_n = [32, 128, 512] # large dim (output_features)
30+
block_sizes_k = [32, 128, 512] # small dim (input_features)
31+
num_warps = [8]
32+
num_stages = [2, 3]
3233
kernel_configs_2D = [
3334
triton.Config(
34-
{"BLOCK_SIZE_N": block_size, "BLOCK_SIZE_K": block_size * 2},
35+
{"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
3536
num_warps=warps,
3637
num_stages=stages,
3738
)
38-
for block_size in block_sizes
39+
for block_size_n in block_sizes_n
40+
for block_size_k in block_sizes_k
3941
for warps in num_warps
4042
for stages in num_stages
4143
]
@@ -62,8 +64,10 @@ def triton_fp8_rowwise_3d_transpose_rhs(
6264

6365
# allocate on-device buffers for output and scales
6466
# output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout
65-
output_buffer = torch.empty((e, k, n), dtype=output_dtype, device=hp_tensor.device)
66-
output_buffer = output_buffer.transpose(-2, -1)
67+
output_buffer = torch.empty(
68+
(e, n, k), dtype=output_dtype, device=hp_tensor.device
69+
).as_strided((e, n, k), (n * k, 1, n))
70+
6771
scales_buffer = torch.full(
6872
(e, k), float("inf"), dtype=torch.float32, device=hp_tensor.device
6973
)

0 commit comments

Comments
 (0)