[moe training] use smaller block sizes for per group scaling kernels to improve perf #2668

Merged
merged 1 commit on Aug 6, 2025
21 changes: 11 additions & 10 deletions torchao/prototype/moe_training/benchmarks/benchmark_kernels.py
@@ -6,13 +6,13 @@
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
import time
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_col_major_jagged_colwise_scales,
@@ -129,18 +129,15 @@ def run_triton(

# bench torch
compiled_run_torch = torch.compile(run_torch)
warmup(compiled_run_torch, input_row_major, input_col_major, offs)
start_time_ns = time.perf_counter_ns()
compiled_run_torch(input_row_major, input_col_major, offs)
torch_time_ns = time.perf_counter_ns() - start_time_ns
torch_time_us = torch_time_ns / 1e3
torch_time_us = benchmark_cuda_function_in_microseconds(
compiled_run_torch, input_row_major, input_col_major, offs
)

# bench triton
warmup(run_triton, input_row_major, input_col_major, offs)
start_time_ns = time.perf_counter_ns()
run_triton(input_row_major, input_col_major, offs)
triton_time_ns = time.perf_counter_ns() - start_time_ns
triton_time_us = triton_time_ns / 1e3
triton_time_us = benchmark_cuda_function_in_microseconds(
run_triton, input_row_major, input_col_major, offs
)

return ExperimentResult(
torch_time_us=torch_time_us,
@@ -173,6 +170,10 @@ def print_results(experiments: List[Experiment]):
print(tabulate(rows, headers=headers))


def benchmark_cuda_function_in_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def main():
torch.random.manual_seed(123)
configs = get_configs()
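Note on the benchmarking change above: the manual `time.perf_counter_ns()` timing it replaces measures host-side wall clock around asynchronous CUDA launches, so it can mis-report kernel time; `triton.testing.do_bench` handles warmup, synchronization, and repetition, and returns milliseconds, which the new helper converts to microseconds. A minimal standalone sketch of how the helper is used (`scale_rowwise` is a made-up stand-in for `compiled_run_torch` / `run_triton`):

```python
import torch
from triton.testing import do_bench


def benchmark_cuda_function_in_microseconds(f, *args):
    # do_bench reports milliseconds; take the median and convert to microseconds.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def scale_rowwise(x):
    # Hypothetical stand-in workload for the benchmarked functions.
    return x / x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)


if __name__ == "__main__":
    x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
    print(f"median: {benchmark_cuda_function_in_microseconds(scale_rowwise, x):.1f} us")
```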
82 changes: 47 additions & 35 deletions torchao/prototype/moe_training/kernels/jagged_float8_scales.py
@@ -16,8 +16,6 @@
import triton
import triton.language as tl

from torchao.prototype.moe_training.utils import _is_column_major

EPS = 1e-12

FP8_DTYPE_MAP = {
@@ -33,13 +31,20 @@
torch.float64: tl.float64,
}

block_sizes = [128, 256]
block_sizes = [1, 16, 32, 64]
block_sizes_iter = [32, 64, 128, 256]
num_warps = [1, 4]
num_stages = [2, 3]
kernel_configs_2D = [
triton.Config(
{"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols}
{"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
num_warps=warps,
num_stages=stages,
)
for block_size_rows in block_sizes
for block_size_cols in block_sizes
for block_size in block_sizes
for block_size_iter in block_sizes_iter
for warps in num_warps
for stages in num_stages
]

from torch.library import triton_op, wrap_triton
@@ -68,7 +73,6 @@ def triton_fp8_row_major_jagged_rowwise_scales(
- jagged rowwise scales (i.e., rowwise scales for each group)
"""
assert hp_tensor.ndim == 2, "input tensor must be 2D"
assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

num_elements = hp_tensor.numel()
tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -81,16 +85,14 @@ def triton_fp8_row_major_jagged_rowwise_scales(
n_groups = offsets.numel()

# allocate on-device buffers for output and scales
output_buffer = torch.empty_like(
hp_tensor, dtype=output_dtype, device=hp_tensor.device
)
output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device)
scales_buffer = torch.empty(
(m * n_groups), dtype=torch.float32, device=hp_tensor.device
)

# parallelize across rows and groups (offsets)
grid = lambda meta: (
triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
triton.cdiv(m, meta["BLOCK_SIZE"]),
offsets.numel(),
)
wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
@@ -115,7 +117,13 @@ def triton_fp8_row_major_jagged_rowwise_scales(
return output_buffer, scales_buffer


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
# This kernel is used on grad_output.t() which has shape (K, M),
# before the calculation `grad_B = grad_output_t @ input`.
# However, in this code, we use the conventional dim names (M, K)
# so the kernel is easily interpretable in a standalone fashion.
# The tokens per expert will vary per iteration, so we don't want
# to recompile when the `token` dim (K, in this case) changes.
@triton.autotune(configs=kernel_configs_2D, key=["M"])
@triton.jit
def _triton_fp8_row_major_jagged_rowwise_scales(
input_ptr,
@@ -134,8 +142,8 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
round_scales_to_power_of_2: tl.constexpr,
BLOCK_SIZE_ROWS: tl.constexpr,
BLOCK_SIZE_COLS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_SIZE_ITER: tl.constexpr,
EPS: tl.constexpr,
):
# parallel across rows and groups (offsets)
@@ -147,12 +155,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
)
group_col_end_idx = tl.load(offsets_ptr + offset_idx)
block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

# compute rowwise amaxes for this group
amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=input_dtype)
for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
block_offs = (
block_row_offs[:, None] * stride_input_row
+ block_col_offs[None, :] * stride_input_col
@@ -180,12 +188,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
# store rowwise scales for each group in contiguous memory:
# [group0_row0, group_0_row1, ..., group2_row0, group2_row1]
scales_offs = block_row_offs + (M * offset_idx)
scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M
scales_mask = tl.arange(0, BLOCK_SIZE) < M
tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

# perform float8 conversion for this group
for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
block_offs = (
block_row_offs[:, None] * stride_input_row
+ block_col_offs[None, :] * stride_input_col
@@ -230,7 +238,6 @@ def triton_fp8_col_major_jagged_colwise_scales(
- jagged column-wise scales (i.e., column-wise scales for each group)
"""
assert hp_tensor.ndim == 2, "input tensor must be 2D"
assert _is_column_major(hp_tensor), "input tensor must be column-major"

num_elements = hp_tensor.numel()
tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -242,17 +249,18 @@ def triton_fp8_col_major_jagged_colwise_scales(
k, n = hp_tensor.shape
n_groups = offsets.numel()

# allocate on-device buffers for output and scales
# Output buffer in column major
output_buffer = torch.empty_like(
hp_tensor, dtype=output_dtype, device=hp_tensor.device
)
).as_strided(hp_tensor.size(), (1, k))

scales_buffer = torch.empty(
(n * n_groups), dtype=torch.float32, device=hp_tensor.device
)

# parallelize across columns and groups (offsets)
grid = lambda meta: (
triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
triton.cdiv(n, meta["BLOCK_SIZE"]),
offsets.numel(),
)
wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
@@ -277,7 +285,11 @@ def triton_fp8_col_major_jagged_colwise_scales(
return output_buffer, scales_buffer


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
# This kernel is used on `input` which has shape (M, K),
# before the calculation `grad_B = grad_output_t @ input`.
# The tokens per expert will vary per iteration, so we don't want
# to recompile when the `token` dim (M) changes.
@triton.autotune(configs=kernel_configs_2D, key=["K"])
@triton.jit
def _triton_fp8_col_major_jagged_colwise_scales(
input_ptr,
@@ -296,8 +308,8 @@ def _triton_fp8_col_major_jagged_colwise_scales(
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
round_scales_to_power_of_2: tl.constexpr,
BLOCK_SIZE_ROWS: tl.constexpr,
BLOCK_SIZE_COLS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_SIZE_ITER: tl.constexpr,
EPS: tl.constexpr,
):
# parallel across columns and groups (offsets)
@@ -309,12 +321,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
)
group_row_end_idx = tl.load(offsets_ptr + offset_idx)
block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

# compute colwise amaxes for this group
amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=input_dtype)
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
block_offs = (
block_row_offs[:, None] * stride_input_row
+ block_col_offs[None, :] * stride_input_col
@@ -343,12 +355,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
# [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
# note: input tensor is in col-major memory layout.
scales_offs = block_col_offs + (N * offset_idx)
scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N
scales_mask = tl.arange(0, BLOCK_SIZE) < N
tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

# perform float8 conversion for this group
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
block_offs = (
block_row_offs[:, None] * stride_input_row
+ block_col_offs[None, :] * stride_input_col
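For readers tracing the renamed tile parameters above: each kernel parallelizes over the non-token dim in tiles of `BLOCK_SIZE` (and over groups), and loops over each group's token dim in steps of `BLOCK_SIZE_ITER`, which is why the autotune key can be just `M` (or `K`) rather than `num_elements`. As a rough pure-PyTorch reference for what the row-major jagged rowwise path computes, here is a sketch for illustration only; `ref_jagged_rowwise_scales`, the empty-group handling, and the exact scale formula are assumptions, not the kernel's code:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
EPS = 1e-12


def ref_jagged_rowwise_scales(x: torch.Tensor, offs: torch.Tensor):
    """Per-group rowwise float8 scaling of a row-major (M, K) tensor.

    `offs` holds each group's cumulative end column, as in the kernels above.
    Returns the fp8 tensor plus scales laid out [group0_row0, group0_row1, ...].
    """
    m, _ = x.shape
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = torch.empty(m * offs.numel(), dtype=torch.float32, device=x.device)
    start = 0
    for g, end in enumerate(offs.tolist()):
        if end == start:  # empty group: skipped here for simplicity
            continue
        group = x[:, start:end].float()
        amax = group.abs().amax(dim=1).clamp(min=EPS)   # (M,) rowwise amax per group
        scale = FP8_E4M3_MAX / amax                     # one scale per row, per group
        scales[g * m:(g + 1) * m] = scale
        out[:, start:end] = (group * scale[:, None]).clamp(
            -FP8_E4M3_MAX, FP8_E4M3_MAX
        ).to(torch.float8_e4m3fn)
        start = end
    return out, scales
```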
16 changes: 5 additions & 11 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -217,35 +217,29 @@ def backward(ctx, grad_output: torch.Tensor):
use_fast_accum=True,
)

# Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_B: grad_output_t @ A
grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()

# Convert A to float8, column-major for right operand of grouped GEMM:
# needed for grad_B: grad_output @ A
A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)

# grad_B is a special case. Both operands of the grouped gemm will be 2D, with offsets determining the "groups."
# Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.

# Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_B: grad_output_t @ A
grad_output_t_fp8_row_major, grad_output_t_scales = (
triton_fp8_row_major_jagged_rowwise_scales(
grad_output_t_row_major,
grad_output.transpose(-2, -1),
offs,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
)

A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
A_col_major,
A,
offs,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)

# Compute grad_B = grad_output_t @ A.
# grad_B = grad_output_t @ A
# grad_B = (N,M) @ (M,K) = (N,K)
assert not _is_column_major(grad_output_t_fp8_row_major), (
"grad_output_t must be row-major for grad_B = grad_output_t @ A"
)
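On the removed `.contiguous()` round-trips above: a transposed view already carries the needed layout in its strides (the transpose of a row-major (M, N) tensor is a column-major (N, M) view), and the updated kernels read the row/column strides they are passed instead of asserting a particular contiguous layout, so no copies are required. A small sketch of that layout fact; the `_is_column_major` shown here is a plausible reimplementation for illustration, not necessarily the helper in `torchao.prototype.moe_training.utils`:

```python
import torch


def _is_column_major(x: torch.Tensor) -> bool:
    # A 2D tensor is column-major when rows are unit-stride and the column
    # stride equals the number of rows.
    assert x.ndim == 2
    return x.stride(0) == 1 and x.stride(1) == x.size(0)


grad_output = torch.randn(64, 32)              # row-major (M, N), strides (32, 1)
grad_output_t = grad_output.transpose(-2, -1)  # (N, M) view, strides (1, 32)

print(grad_output.is_contiguous())             # True
print(_is_column_major(grad_output_t))         # True: no copy needed

# Before: materialize new layouts with extra device copies.
old_left = grad_output.transpose(-2, -1).contiguous()
# After: pass the view directly; the kernels consume the strides.
new_left = grad_output.transpose(-2, -1)

# A column-major output buffer can likewise be expressed with strides alone,
# as the colwise kernel above does with .as_strided(size, (1, k)).
col_major_buf = torch.empty(32 * 64).as_strided((32, 64), (1, 32))
print(_is_column_major(col_major_buf))         # True
```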