21 changes: 12 additions & 9 deletions test/prototype/mx_formats/test_custom_cast.py
@@ -11,7 +11,8 @@
import torchao.prototype.mx_formats.config as config
from torch.utils._triton import has_triton
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP4_E2M1,
DTYPE_FP4_E3M0,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
F4_E2M1_EXP_BIAS,
@@ -20,10 +21,12 @@
)

from torchao.prototype.mx_formats.custom_cast import (
f32_to_f4_unpacked,
f32_to_f4_e2m1_unpacked,
f32_to_f4_e3m0_unpacked,
f32_to_f6_e2m3_unpacked,
f32_to_f6_e3m2_unpacked,
f4_unpacked_to_f32,
f4_e2m1_unpacked_to_f32,
f4_e3m0_unpacked_to_f32,
f6_e2m3_unpacked_to_f32,
f6_e3m2_unpacked_to_f32,
get_bits,
@@ -189,12 +192,12 @@ def test_float6_e2m3_table():
def _test_fp4_case(f32_val, f32_val_ref, f4_enc_ref):
# 1. verify that a fp32 value gets quantized to correct fp4 encoding
# TODO test on cuda
f4_unpacked = f32_to_f4_unpacked(torch.tensor(f32_val))
f4_unpacked = f32_to_f4_e2m1_unpacked(torch.tensor(f32_val))
s_enc, e_enc, m_enc = get_sem_bits(f4_unpacked, bitwidth=4)
assert s_enc + e_enc + m_enc == f4_enc_ref

# 2. verify that fp4 value gets dequantized to correct fp32 value
f32_dequantized = f4_unpacked_to_f32(f4_unpacked)
f32_dequantized = f4_e2m1_unpacked_to_f32(f4_unpacked)
assert f32_val_ref == f32_dequantized.item()


@@ -310,11 +313,11 @@ def test_fp4_6_0():

def test_fp4_pack_unpack():
orig_vals = torch.Tensor([[0.0, 0.5, 4.0, -0.0], [-0.0, 1.0, -6.0, 3.0]])
orig_vals_f4_unpacked = f32_to_f4_unpacked(orig_vals)
orig_vals_f4_unpacked = f32_to_f4_e2m1_unpacked(orig_vals)
orig_vals_f4_packed = pack_uint4(orig_vals_f4_unpacked)
assert orig_vals_f4_packed.numel() == (orig_vals.numel() / 2)
orig_vals_f4_packed_unpacked = unpack_uint4(orig_vals_f4_packed)
orig_vals_dq = f4_unpacked_to_f32(orig_vals_f4_packed_unpacked)
orig_vals_dq = f4_e2m1_unpacked_to_f32(orig_vals_f4_packed_unpacked)
assert torch.all(orig_vals_dq == orig_vals)


@@ -323,7 +326,7 @@ def test_fp4_pack_unpack():
@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="requires PyTorch >= 2.4")
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
f32_ref = f4_e2m1_unpacked_to_f32(unpack_uint4(packed_vals))
f32_triton = triton_f4_to_bf16(packed_vals).to(torch.float)
assert torch.all(torch.eq(f32_ref, f32_triton))

@@ -334,7 +337,7 @@ def test_fp4_triton_unscaled_cast():
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
mxtensor = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4)
mxtensor = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4_E2M1)

f32_ref = mxtensor.to_dtype(torch.float)
config.use_fp4_custom_triton_dequant_kernel = True
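For context on the pack/unpack round trip exercised by test_fp4_pack_unpack above, here is a minimal sketch of packing two 4-bit codes per byte. It is illustrative only: the nibble ordering is an assumption, and torchao's pack_uint4/unpack_uint4 may order or lay out the nibbles differently.

import torch

def pack_uint4_sketch(x: torch.Tensor) -> torch.Tensor:
    # Pack pairs of 4-bit codes (one per uint8 element) into single bytes.
    # Even-indexed elements go to the high nibble here; this ordering is an
    # assumption and may not match torchao's pack_uint4.
    x = x.contiguous().view(-1)
    return ((x[::2] << 4) | x[1::2]).to(torch.uint8)

def unpack_uint4_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of the sketch above: recover the two 4-bit codes from each byte.
    hi = (packed >> 4) & 0x0F
    lo = packed & 0x0F
    return torch.stack([hi, lo], dim=-1).view(-1)

codes = torch.tensor([0, 5, 9, 15], dtype=torch.uint8)
assert torch.equal(unpack_uint4_sketch(pack_uint4_sketch(codes)), codes)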
188 changes: 188 additions & 0 deletions test/prototype/mx_formats/test_e3m0.py
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from torchao.prototype.mx_formats.constants import (
F32_MIN_NORMAL,
F4_E3M0_EXP_BIAS,
F4_E3M0_MAX,
F4_E3M0_MIN_NORMAL,
)
from torchao.prototype.mx_formats.custom_cast import (
EBITS_F4_E3M0,
f32_to_f4_e3m0_unpacked,
MBITS_F4_E3M0,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx


torch.manual_seed(2)


@pytest.mark.parametrize("hp_dtype", [torch.float32])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("sign", [1, -1])
@pytest.mark.parametrize("use_stochastic_rounding", [False, True])
def test_overflow_cast(hp_dtype, device, sign, use_stochastic_rounding):
Contributor review comment: can we add these tests to test/prototype/mx_formats/test_custom_cast.py to keep the testing of MX numerics in one place?

data_min = sign * F4_E3M0_MAX
data_max = sign * F4_E3M0_MAX * F4_E3M0_MAX
data = (
torch.rand(1024, 1024, dtype=hp_dtype, device=device) * (data_max - data_min)
+ data_min
)

data_lp = f32_to_f4_e3m0_unpacked(data, use_stochastic_rounding)
if sign == 1:
target_lp = torch.full_like(data, 2**EBITS_F4_E3M0 - 1, dtype=torch.uint8)
else:
target_lp = torch.full_like(
data, 2 ** (EBITS_F4_E3M0 + 1) - 1, dtype=torch.uint8
)

torch.testing.assert_close(
data_lp,
target_lp,
atol=0,
rtol=0,
)


@pytest.mark.parametrize("hp_dtype", [torch.float32])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_underflow_cast(hp_dtype, device):
data_min = -F4_E3M0_MIN_NORMAL
data_max = F4_E3M0_MIN_NORMAL
data = (
torch.rand(1024, 1024, dtype=hp_dtype, device=device) * (data_max - data_min)
+ data_min
)

data_lp = f32_to_f4_e3m0_unpacked(data, use_stochastic_rounding=False)
target_lp = torch.where((data >= 0) & (data <= F4_E3M0_MIN_NORMAL / 2), 0, 1).to(
torch.uint8
)
target_lp = torch.where(
data < -F4_E3M0_MIN_NORMAL / 2, 1 + 2**EBITS_F4_E3M0, data_lp
)

torch.testing.assert_close(
data_lp,
target_lp,
atol=0,
rtol=0,
)


@pytest.mark.parametrize("hp_dtype", [torch.float32])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_underflow_cast_use_stochastic_rounding(hp_dtype, device):
data_min = -F4_E3M0_MIN_NORMAL
data_max = F4_E3M0_MIN_NORMAL
data = (
torch.rand(1024, 1024, dtype=hp_dtype, device=device) * (data_max - data_min)
+ data_min
)

data_lp = f32_to_f4_e3m0_unpacked(data, use_stochastic_rounding=True)
target_lp = torch.where((data >= 0) & (data <= F4_E3M0_MIN_NORMAL / 2), 0, 1).to(
torch.uint8
)
target_lp = torch.where(
data < -F4_E3M0_MIN_NORMAL / 2, 1 + 2**EBITS_F4_E3M0, data_lp
)


torch.testing.assert_close(
data_lp,
target_lp,
atol=1,
rtol=0,
)

zeros_in_data_lp = (data_lp == 0).sum().item()
zeros_in_target_lp = (target_lp == 0).sum().item()

    assert (
        zeros_in_data_lp >= zeros_in_target_lp
    ), f"expected stochastic rounding to yield at least as many zeros: {zeros_in_data_lp} >= {zeros_in_target_lp}"


@pytest.mark.parametrize("exp_range", list(range(-2, 4)))
@pytest.mark.parametrize("hp_dtype", [torch.float32])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("sign", [1, -1])
@pytest.mark.parametrize("use_stochastic_rounding", [False, True])
def test_normal_cast(exp_range, hp_dtype, device, sign, use_stochastic_rounding):
if sign == 1:
data_min = pow(2, exp_range)
data_max = pow(2, exp_range + 1)
else:
        data_min = -pow(2, exp_range + 1)
        data_max = -pow(2, exp_range)

data = (
torch.rand(1024, 1024, dtype=hp_dtype, device=device) * (data_max - data_min)
+ data_min
)

data_lp = f32_to_f4_e3m0_unpacked(data, use_stochastic_rounding).to(torch.float32)
if sign == 1:
data_lp = torch.pow(2, data_lp - F4_E3M0_EXP_BIAS)
else:
data_lp = -torch.pow(2, data_lp - F4_E3M0_EXP_BIAS - 8)

torch.testing.assert_close(
data_lp,
data,
atol=data_max - data_min,
rtol=0,
)


@pytest.mark.parametrize("data_range", [1, 0.75, 0.5, 0.25, 0.125])
@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("block_size", [32])
@pytest.mark.parametrize("use_stochastic_rounding", [False, True])
def test_mx_qdq(data_range, hp_dtype, block_size, device, use_stochastic_rounding):
data_min = -data_range
data_max = data_range
data = (
torch.rand(1024, 1024, dtype=hp_dtype, device=device) * (data_max - data_min)
+ data_min
)
scale_e8m0_biased, data_lp = to_mx(
data, "fp4_e3m0", block_size, use_stochastic_rounding
)
mx_args = MXTensor(scale_e8m0_biased, data_lp, "fp4_e3m0", block_size, data.dtype)
data_qdq = mx_args.to_dtype(mx_args._orig_dtype)

scale_e8m0_unbiased = scale_e8m0_biased - 127
scale_fp = torch.pow(
torch.full(scale_e8m0_unbiased.size(), 2.0, device=data.device),
scale_e8m0_unbiased,
)
scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)

data_lp = data.reshape(-1, block_size) / scale_fp.unsqueeze(1)
data_lp = data_lp.reshape(data.shape)

# exclude overflow values whose error is unbounded
saturate_mask = data_lp >= F4_E3M0_MAX
data_qdq = torch.where(saturate_mask, data, data_qdq)

    # the largest error equals max_scale_value * the widest exponent step
max_scale_value = torch.max(scale_fp)
largest_error = max_scale_value * (2**4 - 2**3)

torch.testing.assert_close(
data_qdq,
data,
atol=largest_error,
rtol=0,
)
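As a reading aid for the tests above, here is a minimal sketch of the fp4 e3m0 decode that test_normal_cast reconstructs by hand: codes 0-7 are positive, codes 8-15 are negative, and the magnitude is 2 ** (code - bias). The bias value below is an assumption (the real constant is F4_E3M0_EXP_BIAS from torchao.prototype.mx_formats.constants), and zero/special encodings are deliberately not handled.

import torch

F4_E3M0_EXP_BIAS_ASSUMED = 3  # assumed bias, stand-in for the library constant

def decode_f4_e3m0_unpacked_sketch(code: torch.Tensor) -> torch.Tensor:
    # Mirrors the reconstruction in test_normal_cast:
    #   positive codes c in 0..7  -> 2 ** (c - bias)
    #   negative codes c in 8..15 -> -(2 ** (c - 8 - bias))
    # Zero and any special encodings are ignored here.
    c = code.to(torch.float32)
    pos = torch.pow(2.0, c - F4_E3M0_EXP_BIAS_ASSUMED)
    neg = -torch.pow(2.0, c - 8 - F4_E3M0_EXP_BIAS_ASSUMED)
    return torch.where(code < 8, pos, neg)

print(decode_f4_e3m0_unpacked_sketch(torch.arange(16, dtype=torch.uint8)))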
9 changes: 5 additions & 4 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -9,7 +9,8 @@
import torch
from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP4_E2M1,
DTYPE_FP4_E3M0,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
@@ -128,7 +129,7 @@ def test_exponent_nan_out(elem_dtype):
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda"
) # noqa: E501
elif elem_dtype == DTYPE_FP4:
elif elem_dtype == DTYPE_FP4_E2M1 or elem_dtype == DTYPE_FP4_E3M0:
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda"
) # noqa: E501
@@ -164,7 +165,7 @@ def test_block_sizes(elem_dtype):
Smoke test for various block sizes
"""
for B in (1, 2, 32):
if B == 1 and elem_dtype == DTYPE_FP4:
if B == 1 and (elem_dtype == DTYPE_FP4_E2M1 or elem_dtype == DTYPE_FP4_E3M0):
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)
@@ -177,7 +178,7 @@ def test_transpose(elem_dtype, fp4_triton):
"""
Verify that transposing an MX tensor works
"""
if elem_dtype != DTYPE_FP4 and fp4_triton:
if elem_dtype != DTYPE_FP4_E2M1 and fp4_triton:
pytest.skip("unsupported configuration")

tensor_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
6 changes: 3 additions & 3 deletions test/prototype/test_fp6_llm.py
@@ -34,7 +34,7 @@ def test_to_tc_float6_e3m2_compile(self, device):
x = torch.randn(256, 64, device=device)

expected = to_tc_float6_e3m2(x)
actual = torch.compile(to_tc_float6_e3m2)(x)
actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
@@ -53,7 +53,7 @@ def test_from_tc_float6_e3m2_compile(self, device):
x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)

expected = from_tc_float6_e3m2(x)
actual = torch.compile(from_tc_float6_e3m2)(x)
actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -81,7 +81,7 @@ def test_fp6_llm_linear_compile(self, bias):

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fp6_linear(x)
actual = torch.compile(fp6_linear)(x)
actual = torch.compile(fp6_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
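A note on the fullgraph=True change in this file: passing fullgraph=True asks torch.compile to raise an error on any graph break instead of silently splitting the function, so these tests now also guard the float6 cast helpers against graph breaks. A minimal illustration with a hypothetical stand-in function:

import torch

def quantize_roundtrip(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the cast helpers compiled in the tests above.
    return torch.round(x * 4.0) / 4.0

compiled = torch.compile(quantize_roundtrip, fullgraph=True)  # errors on any graph break
print(compiled(torch.randn(8)))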