modify cast from hp to mx to help inductor fuse #1786

Merged 1 commit on Feb 26, 2025
24 changes: 24 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -6,6 +6,8 @@
 
 import pytest
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         use_fp4_custom_triton_dequant_kernel,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+def test_to_mx_inductor_single_kernel():
+    """
+    Verify that inductor can fuse the cast of a high precision tensor to mx
+    into a single kernel
+    """
+    # TODO(future PR): add fp4 and fp6 here
+    # TODO(#1773): add swizzled scale format here
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
+    out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
+    FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
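Note on what the FileCheck assertion above verifies: each ".run(" call in inductor's generated wrapper code corresponds to one kernel launch, so requiring exactly one ".run(" asserts that the whole cast was fused into a single Triton kernel. Below is a minimal sketch of the same counting pattern applied to a stand-alone pointwise pipeline; it assumes a CUDA device with float8 support (SM 8.9+) and a recent PyTorch, and scale_and_cast is a made-up example function, not part of torchao.

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def scale_and_cast(x: torch.Tensor) -> torch.Tensor:
    # hypothetical pointwise pipeline: scale, clamp, cast to float8, reshape
    y = torch.clamp(x * 0.5, min=-448.0, max=448.0).to(torch.float8_e4m3fn)
    return y.reshape(x.shape)


x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
compiled = torch.compile(scale_and_cast, fullgraph=True)
out, code = run_and_get_code(compiled, x)
# expect a single generated Triton kernel, i.e. exactly one ".run(" call
FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])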
11 changes: 10 additions & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
@@ -205,16 +205,25 @@ def to_mx(
     data_lp = torch.clamp(
         data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
     )
-    data_lp = data_lp.reshape(orig_shape)
 
     # cast to target dtype
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_lp = data_lp.to(elem_dtype)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E2M3:
         data_lp = f32_to_f6_e2m3_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E3M2:
         data_lp = f32_to_f6_e3m2_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP4:
+        # can't reshape at the end without handling it in the packing code,
+        # punt until later since we'll need to rethink the torch.compile
+        # approach for fp4x2 in any case
+        data_lp = data_lp.reshape(orig_shape)
         data_lp = f32_to_f4_unpacked(data_lp)
         data_lp = pack_uint4(data_lp)
     else:
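Why the reshape moved: the PR title and the new comments indicate that reshaping back to the original shape before the element dtype cast was getting in the way of inductor fusing the high precision to mx cast into one kernel, so the reshape now happens after the cast in each branch (and before the fp4 packing, which is deferred). A rough, illustrative way to compare the two orderings on a toy pipeline is sketched below; cast_reshape_before and cast_reshape_after are hypothetical helpers written for this sketch, and the printed kernel counts depend on the PyTorch version and the surrounding graph.

import torch
from torch._inductor.utils import run_and_get_code


def cast_reshape_before(data_hp, scale, max_pos, orig_shape):
    # old ordering: reshape the high precision data before the float8 cast
    data_lp = torch.clamp(data_hp / scale.unsqueeze(1), min=-max_pos, max=max_pos)
    data_lp = data_lp.reshape(orig_shape)
    return data_lp.to(torch.float8_e4m3fn)


def cast_reshape_after(data_hp, scale, max_pos, orig_shape):
    # new ordering from this PR: cast first, reshape at the end
    data_lp = torch.clamp(data_hp / scale.unsqueeze(1), min=-max_pos, max=max_pos)
    data_lp = data_lp.to(torch.float8_e4m3fn)
    return data_lp.reshape(orig_shape)


x = torch.randn(2048 * 64, 32, dtype=torch.bfloat16, device="cuda")
scale = (x.abs().amax(dim=1) / 448.0).clamp(min=1e-4)
for fn in (cast_reshape_before, cast_reshape_after):
    compiled = torch.compile(fn, fullgraph=True)
    _, code = run_and_get_code(compiled, x, scale, 448.0, (2048, 2048))
    # count kernel launches in the generated wrapper code
    print(fn.__name__, code[0].count(".run("))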