24 changes: 24 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -6,6 +6,8 @@

 import pytest
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck

 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         use_fp4_custom_triton_dequant_kernel,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+def test_to_mx_inductor_single_kernel():
+    """
+    Verify that inductor can fuse the cast of a high precision tensor to mx
+    into a single kernel
+    """
+    # TODO(future PR): add fp4 and fp6 here
+    # TODO(#1773): add swizzled scale format here
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
+    out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
+    FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
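The assertion above relies on the shape of inductor's generated wrapper code: it defines a `call(...)` entry point, and each compiled Triton kernel is launched via a `.run(...)` call, so counting `.run(` occurrences approximates the kernel count. A minimal sketch of the same verification pattern on a toy pointwise chain (illustrative only, assuming a CUDA device; `toy_fn` is a made-up example, not part of this diff):

```python
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def toy_fn(x):
    # two pointwise ops that inductor should fuse into a single Triton kernel
    return (x * 2.0).relu()


if __name__ == "__main__":
    x = torch.randn(1024, 1024, device="cuda")
    toy_c = torch.compile(toy_fn, fullgraph=True)
    _, code = run_and_get_code(toy_c, x)
    # one ".run(" in the generated wrapper == one kernel launch (expected here)
    FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
```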
11 changes: 10 additions & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
@@ -205,16 +205,25 @@ def to_mx(
     data_lp = torch.clamp(
         data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
     )
-    data_lp = data_lp.reshape(orig_shape)

     # cast to target dtype
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_lp = data_lp.to(elem_dtype)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E2M3:
         data_lp = f32_to_f6_e2m3_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E3M2:
         data_lp = f32_to_f6_e3m2_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP4:
+        # can't reshape at the end without handling it in the packing code,
+        # punt until later since we'll need to rethink the torch.compile
+        # approach for fp4x2 in any case
+        data_lp = data_lp.reshape(orig_shape)
        data_lp = f32_to_f4_unpacked(data_lp)
        data_lp = pack_uint4(data_lp)
     else:
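Per the diff's own comments, the intent is to reshape after the elementwise cast so inductor can fuse the whole scale/clamp/cast chain, while the fp4 path keeps the reshape before packing. A hedged way to check the effect locally is to compile `MXTensor.to_mx` and count kernel launches in the inductor output, mirroring the new test (assumes a CUDA device and that `MXTensor` is importable from `torchao.prototype.mx_formats.mx_tensor`, the file edited above):

```python
import torch
from torch._inductor.utils import run_and_get_code

from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
_, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, 32)

# each ".run(" in the generated wrapper corresponds to one kernel launch;
# with the reshape placed after the cast this is expected to be 1
print("kernel launches:", code[0].count(".run("))
```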