
[scaled grouped mm] integrate triton kernels into differentiable scaled grouped mm #2077


Merged 9 commits on Apr 22, 2025
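For context, the op these kernels integrate into is the prototype differentiable scaled grouped mm, which the new benchmark in this PR exercises. A minimal usage sketch distilled from that script (assuming a CUDA build on an SM90+ GPU; shapes match the smallest benchmark config):

import torch

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

device = "cuda"

# 2D activations (total_tokens, K) and a 3D expert weight tensor transposed
# into column-major layout, as in the benchmark's 2D x 3D case.
A = torch.randn(256, 4096, dtype=torch.bfloat16, device=device, requires_grad=True)
B_t = torch.randn(
    4, 4096, 4096, dtype=torch.bfloat16, device=device, requires_grad=True
).transpose(-2, -1)

# `offs` holds the end offset of each token group along A's row dimension
# (here: 4 equal groups of 64 rows).
offs = torch.arange(64, 256 + 1, 64, device=device, dtype=torch.int32)

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
out.sum().backward()  # differentiable: gradients flow back to A and B_t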
@@ -7,15 +7,27 @@
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

# We need to skip before doing any imports that would use triton, since
# triton won't be available on CPU-only builds or with torch < 2.5.
if not (
    TORCH_VERSION_AT_LEAST_2_5
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 9
):
    pytest.skip(
        "Requires torch >= 2.5 and a CUDA GPU with compute capability >= 9.0",
        allow_module_level=True,
    )


from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
    triton_fp8_col_major_jagged_colwise_scales,
    triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
from torchao.prototype.scaled_grouped_mm.utils import (
    _is_column_major,
    _to_2d_jagged_float8_tensor_colwise,
    _to_2d_jagged_float8_tensor_rowwise,
)
from torchao.prototype.scaled_grouped_mm.utils import _is_column_major


@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
@@ -7,6 +7,18 @@
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

# We need to skip before doing any imports that would use triton, since
# triton won't be available on CPU-only builds or with torch < 2.5.
if not (
    TORCH_VERSION_AT_LEAST_2_5
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 9
):
    pytest.skip(
        "Requires torch >= 2.5 and a CUDA GPU with compute capability >= 9.0",
        allow_module_level=True,
    )


from torchao.float8.config import (
    Float8LinearConfig,
    Float8LinearRecipeName,
@@ -19,7 +31,6 @@
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_valid_scaled_grouped_mm_2d_3d():
    out_dtype = torch.bfloat16
    device = "cuda"
@@ -73,7 +84,6 @@ def test_valid_scaled_grouped_mm_2d_3d():
    assert torch.equal(b_t.grad, ref_b_t.grad)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("m", [16, 17])
@pytest.mark.parametrize("k", [16, 18])
@pytest.mark.parametrize("n", [32, 33])
@@ -18,7 +18,7 @@
    triton_fp8_col_major_jagged_colwise_scales,
    triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
from torchao.prototype.scaled_grouped_mm.utils import (
    _to_2d_jagged_float8_tensor_colwise,
    _to_2d_jagged_float8_tensor_rowwise,
)
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
import time
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    high_precision_dtype: torch.dtype
    A_shape: tuple[int, int]
    B_shape: tuple[int, int, int]


@dataclass(frozen=True)
class ExperimentResult:
    time_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    A_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
    B_shapes = [(4, 4096, 4096), (8, 4096, 4096), (16, 4096, 4096)]
    high_precision_dtypes = [torch.bfloat16]
    configs = []
    for A_shape, B_shape, high_precision_dtype in itertools.product(
        A_shapes, B_shapes, high_precision_dtypes
    ):
        configs.append(
            ExperimentConfig(
                A_shape=A_shape,
                B_shape=B_shape,
                high_precision_dtype=high_precision_dtype,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    # define test inputs
    A = torch.randn(
        *config.A_shape,
        dtype=config.high_precision_dtype,
        device=device,
        requires_grad=True,
    )
    B_t = torch.randn(
        *config.B_shape,
        dtype=config.high_precision_dtype,
        device=device,
        requires_grad=True,
    ).transpose(-2, -1)

    # - configure input to be row-major with groups divided along the column dimension,
    #   representing the left operand of grad_weight = grad_output_t @ input
    #   that occurs in the backward pass of the differentiable scaled grouped mm.
    # - the transposed tensor in col-major format with groups along the row dimension,
    #   which represents the right operand.
    n_groups = config.B_shape[0]
    group_size = A.shape[0] // n_groups
    offs = torch.arange(
        group_size,
        group_size * n_groups + 1,
        group_size,
        device=device,
        dtype=torch.int32,
    )

    def warmup(func, *args, **kwargs):
        for _ in range(10):
            func(*args, **kwargs)

    def forward_backward(A, B_t, offs):
        out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
        out.sum().backward()

    # bench triton
    warmup(forward_backward, A, B_t, offs)
    # synchronize so queued CUDA work from warmup doesn't leak into the timed
    # region, and so the timed region covers the kernels themselves rather than
    # just their launch overhead
    torch.cuda.synchronize()
    start_time_ns = time.perf_counter_ns()
    forward_backward(A, B_t, offs)
    torch.cuda.synchronize()
    time_ns = time.perf_counter_ns() - start_time_ns
    time_us = time_ns / 1e3

    return ExperimentResult(time_us=time_us)


def print_results(experiments: List[Experiment]):
    headers = [
        "A_shape",
        "B_shape",
        "high_precision_dtype",
        "time_us",
    ]
    rows = []
    for experiment in experiments:
        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
        B_shape = f"({experiment.config.B_shape[0]}, {experiment.config.B_shape[1]}, {experiment.config.B_shape[2]})"
        rows.append(
            [
                A_shape,
                B_shape,
                experiment.config.high_precision_dtype,
                experiment.result.time_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
6 changes: 6 additions & 0 deletions torchao/prototype/scaled_grouped_mm/kernels/__init__.py
@@ -0,0 +1,6 @@
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
    triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
)
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
    triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
)
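With the explicit `as`-style re-exports above, callers can import the triton kernels from the kernels package namespace rather than the submodule; for example (hypothetical call site, kernel signatures not shown):

from torchao.prototype.scaled_grouped_mm.kernels import (
    triton_fp8_col_major_jagged_colwise_scales,
    triton_fp8_row_major_jagged_rowwise_scales,
)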
@@ -160,7 +160,11 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1))
        # we need to cast back to input dtype since triton promotes bf16 to fp32:
        # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1)).to(
            input_dtype
        )

    # compute rowwise scales for this group. round scales to nearest power of 2.
    amax_buffer = amax_buffer.to(tl.float64)
@@ -317,7 +321,11 @@ def _triton_fp8_col_major_jagged_colwise_scales(
        data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
            input_dtype
        )
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0))
        # we need to cast back to input dtype since triton promotes bf16 to fp32:
        # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0)).to(
            input_dtype
        )

    # compute colwise scales for this group.
    amax_buffer = amax_buffer.to(tl.float64)
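On the casts added in both hunks above: the linked triton standard-library code promotes bfloat16 inputs to float32 inside reductions such as tl.max, so without the explicit .to(input_dtype) the running amax buffer would silently switch dtype on the first iteration. A minimal standalone illustration of that behavior (this demo kernel is not part of the PR; it assumes triton and a CUDA device are available):

import torch
import triton
import triton.language as tl


@triton.jit
def _amax_demo_kernel(x_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)          # bfloat16 block
    amax = tl.max(tl.abs(x), axis=0)   # reduction result is promoted to float32
    amax = amax.to(x.dtype)            # cast back to bfloat16, mirroring the diff above
    tl.store(out_ptr, amax)


x = torch.randn(128, device="cuda", dtype=torch.bfloat16)
out = torch.empty(1, device="cuda", dtype=torch.bfloat16)
_amax_demo_kernel[(1,)](x, out, N=128)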