
[moe training] add bench script for fp8 rowwise kernels and update autotune configs #2697

Merged 1 commit on Aug 7, 2025
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
    triton_fp8_col_major_jagged_colwise_scales,
    triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.moe_training.utils import (
    torch_to_float8_per_group_colwise,
    torch_to_float8_per_group_rowwise,
)

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    high_precision_dtype: torch.dtype
    input_shape: tuple[int, ...]
    n_groups: int


@dataclass(frozen=True)
class ExperimentResult:
    torch_time_us: float
    triton_time_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
    n_groups_list = [4, 8, 16]
    high_precision_dtypes = [torch.bfloat16]
    configs = []
    for input_shape, n_groups, high_precision_dtype in itertools.product(
        input_shapes, n_groups_list, high_precision_dtypes
    ):
        configs.append(
            ExperimentConfig(
                input_shape=input_shape,
                n_groups=n_groups,
                high_precision_dtype=high_precision_dtype,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    # define test inputs
    input_tensor = torch.randn(
        *config.input_shape,
        dtype=config.high_precision_dtype,
        device=device,
    )
    input_row_major = input_tensor.clone().detach()
    input_col_major = input_tensor.clone().detach().t()

    # - the row-major input, with groups divided along the column dimension,
    #   represents the left operand of grad_weight = grad_output_t @ input
    #   that occurs in the backward pass of the differentiable scaled grouped mm.
    # - the transposed tensor, in col-major format with groups along the row
    #   dimension, represents the right operand.
    group_size = input_row_major.shape[1] // config.n_groups
    n_groups = config.n_groups
    offs = torch.arange(
        group_size,
        group_size * n_groups + 1,
        group_size,
        device=device,
        dtype=torch.int32,
    )
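    # e.g., with input_shape (2**12, 4096) and n_groups=4, group_size = 1024 and
    # offs = [1024, 2048, 3072, 4096]: each entry is the exclusive end index of
    # one group along the jagged dimension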

    def warmup(func, *args, **kwargs):
        for _ in range(10):
            func(*args, **kwargs)

    def run_torch(
        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
    ):
        _ = torch_to_float8_per_group_rowwise(
            input_row_major,
            offs,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        _ = torch_to_float8_per_group_colwise(
            input_col_major,
            offs,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )

    def run_triton(
        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
    ):
        _ = triton_fp8_row_major_jagged_rowwise_scales(
            input_row_major,
            offs,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        _ = triton_fp8_col_major_jagged_colwise_scales(
            input_col_major,
            offs,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )

    # bench torch (do_bench's internal warmup phase also triggers compilation)
    compiled_run_torch = torch.compile(run_torch)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch, input_row_major, input_col_major, offs
    )

    # bench triton
    warmup(run_triton, input_row_major, input_col_major, offs)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        run_triton, input_row_major, input_col_major, offs
    )

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "n_groups",
        "high_precision_dtype",
        "torch_time_us",
        "triton_time_us",
    ]
    rows = []
    for experiment in experiments:
        input_shape = (
            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
        )
        rows.append(
            [
                input_shape,
                experiment.config.n_groups,
                experiment.config.high_precision_dtype,
                experiment.result.torch_time_us,
                experiment.result.triton_time_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def benchmark_cuda_function_in_microseconds(f, *args):
    # do_bench returns the median runtime in milliseconds; convert to microseconds
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
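For orientation (not part of the diff): below is a minimal eager-mode sketch of the per-group rowwise conversion this script benchmarks. The amax-based scale derivation and the floor-to-power-of-2 rounding are assumptions about the reference semantics; torch_to_float8_per_group_rowwise in torchao.prototype.moe_training.utils is the authoritative implementation.

import torch

def per_group_rowwise_fp8_sketch(x: torch.Tensor, offs: torch.Tensor):
    # x: (M, K) high-precision tensor; offs holds the exclusive end column of
    # each group along the K dimension, as built in run_experiment above.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = []
    start = 0
    for end in offs.tolist():
        group = x[:, start:end]
        # one scale per row within each group, derived from the row amax
        amax = group.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
        scale = fp8_max / amax.float()
        # round scales down to a power of 2, mirroring round_scales_to_power_of_2=True
        scale = torch.exp2(torch.floor(torch.log2(scale)))
        out[:, start:end] = (
            (group.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        )
        scales.append(scale)
        start = end
    return out, torch.cat(scales, dim=1)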
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.moe_training.kernels.float8_rowwise import (
    triton_fp8_rowwise_3d_transpose_rhs,
)
from torchao.prototype.moe_training.utils import (
    torch_to_3d_rowwise_float8_transpose_rhs,
)

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    high_precision_dtype: torch.dtype
    input_shape: tuple[int, ...]


@dataclass(frozen=True)
class ExperimentResult:
    torch_time_us: float
    triton_time_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 and DeepSeekV3 shapes
    input_shapes = [(8, 4096, 1024), (16, 5120 * 4, 5120)]
    high_precision_dtypes = [torch.bfloat16]
    configs = []
    for input_shape, high_precision_dtype in itertools.product(
        input_shapes, high_precision_dtypes
    ):
        configs.append(
            ExperimentConfig(
                input_shape=input_shape,
                high_precision_dtype=high_precision_dtype,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    # Expert weights will be passed in transposed and column major in practice
    input_tensor = torch.randn(
        *config.input_shape,
        dtype=config.high_precision_dtype,
        device=device,
    ).transpose(-2, -1)
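    # e.g., for input_shape (8, 4096, 1024) the transpose yields an
    # (8, 1024, 4096) view with strides (4096 * 1024, 1, 1024), i.e.
    # column-major within each expert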

    def warmup(func, *args, **kwargs):
        for _ in range(10):
            func(*args, **kwargs)

    def run_torch(input_tensor: torch.Tensor):
        out = torch_to_3d_rowwise_float8_transpose_rhs(
            input_tensor,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        torch.cuda.synchronize()
        return out

    def run_triton(input_tensor: torch.Tensor):
        _ = triton_fp8_rowwise_3d_transpose_rhs(
            input_tensor,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        torch.cuda.synchronize()

    # bench torch (warm up the compiled function so compilation isn't measured)
    compiled_run_torch = torch.compile(run_torch)
    warmup(compiled_run_torch, input_tensor)
    torch_time_us = benchmark_cuda_function_in_microseconds(
        compiled_run_torch,
        input_tensor,
    )

    # bench triton
    warmup(run_triton, input_tensor)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        run_triton,
        input_tensor,
    )

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "torch_time_us",
        "triton_time_us",
    ]
    rows = []
    for experiment in experiments:
        # format all three dims explicitly (avoids nesting a tuple in the f-string)
        input_shape = (
            f"({experiment.config.input_shape[0]}, "
            f"{experiment.config.input_shape[1]}, "
            f"{experiment.config.input_shape[2]})"
        )
        rows.append(
            [
                input_shape,
                experiment.result.torch_time_us,
                experiment.result.triton_time_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def benchmark_cuda_function_in_microseconds(f, *args):
    # do_bench returns the median runtime in milliseconds; convert to microseconds
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
18 changes: 11 additions & 7 deletions torchao/prototype/moe_training/kernels/float8_rowwise.py
@@ -26,16 +26,18 @@
     torch.float64: tl.float64,
 }
 
-block_sizes = [16]
-num_warps = [4]
-num_stages = [2]
+block_sizes_n = [32, 128, 512]  # large dim (output_features)
+block_sizes_k = [32, 128, 512]  # small dim (input_features)
+num_warps = [8]
+num_stages = [2, 3]
 kernel_configs_2D = [
     triton.Config(
-        {"BLOCK_SIZE_N": block_size, "BLOCK_SIZE_K": block_size * 2},
+        {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
         num_warps=warps,
         num_stages=stages,
     )
-    for block_size in block_sizes
+    for block_size_n in block_sizes_n
+    for block_size_k in block_sizes_k
     for warps in num_warps
     for stages in num_stages
 ]
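With the expanded lists, the 2D kernel's autotune search space grows from a single configuration (BLOCK_SIZE_N=16, BLOCK_SIZE_K=32, 4 warps, 2 stages) to 3 × 3 × 1 × 2 = 18 candidates, letting the autotuner pick block sizes independently for the large (N) and small (K) dims.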
@@ -62,8 +64,10 @@ def triton_fp8_rowwise_3d_transpose_rhs(

     # allocate on-device buffers for output and scales
     # output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout
-    output_buffer = torch.empty((e, k, n), dtype=output_dtype, device=hp_tensor.device)
-    output_buffer = output_buffer.transpose(-2, -1)
+    output_buffer = torch.empty(
+        (e, n, k), dtype=output_dtype, device=hp_tensor.device
+    ).as_strided((e, n, k), (n * k, 1, n))
 
     scales_buffer = torch.full(
         (e, k), float("inf"), dtype=torch.float32, device=hp_tensor.device
     )
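The new single-shot allocation reproduces the exact layout of the old two-step version (allocate (E, K, N) contiguous, then transpose) while skipping the intermediate view; a quick self-contained sanity check (shapes arbitrary, not taken from the PR):

import torch

e, n, k = 8, 1024, 4096
old = torch.empty((e, k, n)).transpose(-2, -1)
new = torch.empty((e, n, k)).as_strided((e, n, k), (n * k, 1, n))
assert old.shape == new.shape == (e, n, k)
assert old.stride() == new.stride() == (n * k, 1, n)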