diff --git a/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py b/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py
index 37701e6545..7068fe5b58 100644
--- a/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py
+++ b/torchao/prototype/moe_training/benchmarks/benchmark_kernels.py
@@ -6,13 +6,13 @@
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
 
 import itertools
-import time
 from dataclasses import dataclass
 from typing import List
 
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
+from triton.testing import do_bench
 
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
@@ -129,18 +129,15 @@ def run_triton(
 
     # bench torch
     compiled_run_torch = torch.compile(run_torch)
-    warmup(compiled_run_torch, input_row_major, input_col_major, offs)
-    start_time_ns = time.perf_counter_ns()
-    compiled_run_torch(input_row_major, input_col_major, offs)
-    torch_time_ns = time.perf_counter_ns() - start_time_ns
-    torch_time_us = torch_time_ns / 1e3
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        compiled_run_torch, input_row_major, input_col_major, offs
+    )
 
     # bench triton
     warmup(run_triton, input_row_major, input_col_major, offs)
-    start_time_ns = time.perf_counter_ns()
-    run_triton(input_row_major, input_col_major, offs)
-    triton_time_ns = time.perf_counter_ns() - start_time_ns
-    triton_time_us = triton_time_ns / 1e3
+    triton_time_us = benchmark_cuda_function_in_microseconds(
+        run_triton, input_row_major, input_col_major, offs
+    )
 
     return ExperimentResult(
         torch_time_us=torch_time_us,
@@ -173,6 +170,10 @@ def print_results(experiments: List[Experiment]):
     print(tabulate(rows, headers=headers))
 
 
+def benchmark_cuda_function_in_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+
+
 def main():
     torch.random.manual_seed(123)
     configs = get_configs()
diff --git a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
index 2c19fdc5a2..ff0b11acba 100644
--- a/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
+++ b/torchao/prototype/moe_training/kernels/jagged_float8_scales.py
@@ -16,8 +16,6 @@
 import triton
 import triton.language as tl
 
-from torchao.prototype.moe_training.utils import _is_column_major
-
 EPS = 1e-12
 
 FP8_DTYPE_MAP = {
@@ -33,13 +31,20 @@
     torch.float64: tl.float64,
 }
 
-block_sizes = [128, 256]
+block_sizes = [1, 16, 32, 64]
+block_sizes_iter = [32, 64, 128, 256]
+num_warps = [1, 4]
+num_stages = [2, 3]
 kernel_configs_2D = [
     triton.Config(
-        {"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols}
+        {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
+        num_warps=warps,
+        num_stages=stages,
     )
-    for block_size_rows in block_sizes
-    for block_size_cols in block_sizes
+    for block_size in block_sizes
+    for block_size_iter in block_sizes_iter
+    for warps in num_warps
+    for stages in num_stages
 ]
 
 from torch.library import triton_op, wrap_triton
@@ -68,7 +73,6 @@ def triton_fp8_row_major_jagged_rowwise_scales(
         - jagged rowwise scales (i.e., rowwise scales for each group)
     """
     assert hp_tensor.ndim == 2, "input tensor must be 2D"
-    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"
 
     num_elements = hp_tensor.numel()
     tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -81,16 +85,14 @@ def triton_fp8_row_major_jagged_rowwise_scales(
     n_groups = offsets.numel()
 
     # allocate on-device buffers for output and scales
-    output_buffer = torch.empty_like(
-        hp_tensor, dtype=output_dtype, device=hp_tensor.device
-    )
+    output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device)
     scales_buffer = torch.empty(
         (m * n_groups), dtype=torch.float32, device=hp_tensor.device
     )
 
     # parallelize across rows and groups (offsets)
     grid = lambda meta: (
-        triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
+        triton.cdiv(m, meta["BLOCK_SIZE"]),
         offsets.numel(),
     )
     wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
@@ -115,7 +117,13 @@ def triton_fp8_row_major_jagged_rowwise_scales(
     return output_buffer, scales_buffer
 
 
-@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+# This kernel is used on grad_output.t() which has shape (K, M),
+# before the calculation `grad_B = grad_output_t @ input`.
+# However, in this code, we use the conventional dim names (M, K)
+# so the kernel is easily interpretable in a standalone fashion.
+# The tokens per expert will vary per iteration, so we don't want
+# to recompile on `token` dim (K, in this case) changes.
+@triton.autotune(configs=kernel_configs_2D, key=["M"])
 @triton.jit
 def _triton_fp8_row_major_jagged_rowwise_scales(
     input_ptr,
@@ -134,8 +142,8 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     input_dtype: tl.constexpr,
     output_dtype: tl.constexpr,
     round_scales_to_power_of_2: tl.constexpr,
-    BLOCK_SIZE_ROWS: tl.constexpr,
-    BLOCK_SIZE_COLS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_SIZE_ITER: tl.constexpr,
     EPS: tl.constexpr,
 ):
     # parallel across rows and groups (offsets)
@@ -147,12 +155,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
         offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
     )
     group_col_end_idx = tl.load(offsets_ptr + offset_idx)
-    block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
+    block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
 
     # compute rowwise amaxes for this group
-    amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=input_dtype)
-    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+    amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
+    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
         block_offs = (
             block_row_offs[:, None] * stride_input_row
             + block_col_offs[None, :] * stride_input_col
@@ -180,12 +188,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     # store rowwise scales for each group in contiguous memory:
     # [group0_row0, group_0_row1, ..., group2_row0, group2_row1]
     scales_offs = block_row_offs + (M * offset_idx)
-    scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M
+    scales_mask = tl.arange(0, BLOCK_SIZE) < M
     tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)
 
     # perform float8 conversion for this group
-    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
         block_offs = (
             block_row_offs[:, None] * stride_input_row
             + block_col_offs[None, :] * stride_input_col
@@ -230,7 +238,6 @@ def triton_fp8_col_major_jagged_colwise_scales(
         - jagged column-wise scales (i.e., column-wise scales for each group)
     """
     assert hp_tensor.ndim == 2, "input tensor must be 2D"
-    assert _is_column_major(hp_tensor), "input tensor must be column-major"
 
     num_elements = hp_tensor.numel()
     tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -242,17 +249,18 @@ def triton_fp8_col_major_jagged_colwise_scales(
     k, n = hp_tensor.shape
     n_groups = offsets.numel()
 
-    # allocate on-device buffers for output and scales
+    # Output buffer in column major
    output_buffer = torch.empty_like(
         hp_tensor, dtype=output_dtype, device=hp_tensor.device
-    )
+    ).as_strided(hp_tensor.size(), (1, k))
+
     scales_buffer = torch.empty(
         (n * n_groups), dtype=torch.float32, device=hp_tensor.device
     )
 
     # parallelize across columns and groups (offsets)
     grid = lambda meta: (
-        triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
+        triton.cdiv(n, meta["BLOCK_SIZE"]),
         offsets.numel(),
     )
     wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
@@ -277,7 +285,11 @@ def triton_fp8_col_major_jagged_colwise_scales(
     return output_buffer, scales_buffer
 
 
-@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+# This kernel is used on `input` which has shape (M, K),
+# before the calculation `grad_B = grad_output_t @ input`.
+# The tokens per expert will vary per iteration, so we don't want
+# to recompile on `token` dim (M) changes.
+@triton.autotune(configs=kernel_configs_2D, key=["K"])
 @triton.jit
 def _triton_fp8_col_major_jagged_colwise_scales(
     input_ptr,
@@ -296,8 +308,8 @@ def _triton_fp8_col_major_jagged_colwise_scales(
     input_dtype: tl.constexpr,
     output_dtype: tl.constexpr,
     round_scales_to_power_of_2: tl.constexpr,
-    BLOCK_SIZE_ROWS: tl.constexpr,
-    BLOCK_SIZE_COLS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_SIZE_ITER: tl.constexpr,
     EPS: tl.constexpr,
 ):
     # parallel across columns and groups (offsets)
@@ -309,12 +321,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
         offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
     )
     group_row_end_idx = tl.load(offsets_ptr + offset_idx)
-    block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
+    block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
 
     # compute colwise amaxes for this group
-    amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=input_dtype)
-    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
-        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
+    amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
+    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
+        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
         block_offs = (
             block_row_offs[:, None] * stride_input_row
             + block_col_offs[None, :] * stride_input_col
@@ -343,12 +355,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
     # [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
     # note: input tensor is in col-major memory layout.
     scales_offs = block_col_offs + (N * offset_idx)
-    scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N
+    scales_mask = tl.arange(0, BLOCK_SIZE) < N
     tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)
 
     # perform float8 conversion for this group
-    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
-        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
+    for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
+        block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
         block_offs = (
             block_row_offs[:, None] * stride_input_row
             + block_col_offs[None, :] * stride_input_col
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index fd22186939..f06b47d4cc 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -217,19 +217,14 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )
 
-        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
-        # needed for grad_B: grad_output_t @ A
-        grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()
-
-        # Convert A to float8, column-major for right operand of grouped GEMM:
-        # needed for grad_B: grad_output @ A
-        A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)
-
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
+
+        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
+        # needed for grad_B: grad_output_t @ A
         grad_output_t_fp8_row_major, grad_output_t_scales = (
             triton_fp8_row_major_jagged_rowwise_scales(
-                grad_output_t_row_major,
+                grad_output.transpose(-2, -1),
                 offs,
                 torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
@@ -237,7 +232,7 @@ def backward(ctx, grad_output: torch.Tensor):
         )
 
         A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
-            A_col_major,
+            A,
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
@@ -245,7 +240,6 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        # grad_B = (N,M) @ (M,K) = (N,K)
         assert not _is_column_major(grad_output_t_fp8_row_major), (
             "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
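
Quick sanity check of the new timing helper (not part of the patch): `benchmark_cuda_function_in_microseconds` wraps `triton.testing.do_bench`, which handles warmup and CUDA synchronization and returns the median runtime in milliseconds, so multiplying by 1e3 yields microseconds. A minimal sketch, assuming a CUDA device and an arbitrary matmul workload:

# Illustrative sketch only -- mirrors the helper added in benchmark_kernels.py.
# The matmul workload and tensor shapes below are arbitrary examples.
import torch
from triton.testing import do_bench


def benchmark_cuda_function_in_microseconds(f, *args):
    # do_bench runs the callable repeatedly on the GPU and returns milliseconds.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
print(f"matmul: {benchmark_cuda_function_in_microseconds(torch.matmul, a, b):.1f} us")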
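
For reference, a hedged usage sketch of the two jagged-scale kernels as the updated backward pass calls them (not part of the patch; the shapes, dtypes, and `offs` values are made-up examples). The row-major kernel now accepts the non-contiguous `grad_output.transpose(-2, -1)` directly, and the col-major kernel takes row-major `A` and writes its output column-major via `as_strided`:

# Illustrative sketch only -- exercises the kernels the way backward() does above.
import torch

from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
    triton_fp8_col_major_jagged_colwise_scales,
    triton_fp8_row_major_jagged_rowwise_scales,
)

total_tokens, K, N = 1024, 512, 256  # made-up sizes
# cumulative group boundaries along the token dim (3 "experts" of 256/256/512 tokens)
offs = torch.tensor([256, 512, 1024], dtype=torch.int32, device="cuda")

A = torch.randn(total_tokens, K, dtype=torch.bfloat16, device="cuda")  # (M, K)
grad_out = torch.randn(total_tokens, N, dtype=torch.bfloat16, device="cuda")  # (M, N)

# (N, M) operand: rowwise scales per group along the jagged token dim, no .contiguous() needed.
go_t_fp8, go_t_scales = triton_fp8_row_major_jagged_rowwise_scales(
    grad_out.transpose(-2, -1), offs, torch.float8_e4m3fn, round_scales_to_power_of_2=True
)
# (M, K) operand: colwise scales per group; the output buffer comes back column-major.
A_fp8, A_scales = triton_fp8_col_major_jagged_colwise_scales(
    A, offs, torch.float8_e4m3fn, round_scales_to_power_of_2=True
)
print(go_t_fp8.shape, go_t_scales.shape, A_fp8.stride(), A_scales.shape)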