Add mx_fp4_kernel #1661

Merged: 1 commit merged on Feb 12, 2025

90 changes: 31 additions & 59 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -2,63 +2,45 @@
import torch

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_100,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


def run_matrix_test(M: int, K: int, N: int) -> float:
"""
Run matrix multiplication test with given dimensions.

Args:
M, K, N: Matrix dimensions

Returns:
float: SQNR (Signal-to-Quantization-Noise Ratio) value
"""
def run_matrix_test(M: int, K: int, N: int, format) -> float:
dtype = torch.bfloat16
device = torch.device("cuda")

# Initialize matrices
a = torch.rand((M, K), dtype=dtype, device=device)
b = torch.rand((N, K), dtype=dtype, device=device)

# Convert to MX format
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16

a_fp8 = a_mx._data
b_fp8 = b_mx._data
assert b_fp8.is_contiguous()
b_fp8 = b_fp8.transpose(-1, -2)
a_mx = MXTensor.to_mx(a, fmt, 32)
b_mx = MXTensor.to_mx(b, fmt, 32)

# Get scales
a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
a_data = a_mx._data
b_data = b_mx._data
assert b_data.is_contiguous()
b_data = b_data.transpose(-1, -2)

a_scale_block = to_blocked(a_scale_e8)
b_scale_block = to_blocked(b_scale_e8)
a_scale = a_mx._scale_e8m0.view(M, K // 32)
b_scale = b_mx._scale_e8m0.view(N, K // 32)

a_scale_block = to_blocked(a_scale)
b_scale_block = to_blocked(b_scale)

# Get reference output
out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-1, -2
)
out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

# Run implementation
out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)

# Calculate metrics
sqnr = compute_error(out_hp, out_e8_fp8)

return sqnr.item()
return compute_error(out_hp, out).item()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -68,35 +50,25 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
@pytest.mark.parametrize(
"size",
[
# Small matrices
(128, 128, 128),
(256, 256, 256),
(384, 384, 384),
# Medium matrices
(384, 384, 384), # Small
(512, 512, 512),
(640, 640, 640),
(768, 768, 768),
# Large matrices
(896, 896, 896),
(768, 768, 768), # Medium
(1024, 1024, 1024),
# Very large matrices
(8192, 8192, 8192),
# Non-square matrices
(8192, 8192, 8192), # Large
(128, 256, 384),
(256, 384, 512),
(384, 512, 640),
# Non-aligned matrices
(256, 384, 512), # Non-square
(129, 256, 384),
(256, 384, 536),
(133, 512, 528),
(133, 512, 528), # Non-aligned
],
ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
def test_matrix_multiplication(size):
"""
Test matrix multiplication with various dimensions.
Verifies that the SQNR meets minimum quality threshold.
"""
@pytest.mark.parametrize("format", ["fp8", "fp4"])
def test_matrix_multiplication(size, format):
M, K, N = size
sqnr = run_matrix_test(M, K, N)
assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
sqnr = run_matrix_test(M, K, N, format)
threshold = 80.0
assert (
sqnr >= threshold
), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
50 changes: 42 additions & 8 deletions torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu
@@ -34,7 +34,7 @@ using namespace cute;

template<typename Element>
constexpr int GetAlignment() {
if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
if constexpr (std::is_same_v<Element, cutlass::mx_float4_t<cutlass::float_e2m1_t>>)
return 32;
return 16;
}
@@ -46,11 +46,7 @@ template <typename ElementA,
typename ClusterShape,
typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
at::Tensor& b_scale, at::Tensor& out) {
int M = a.size(0);
int K = a.size(1);
int N = b.size(1);

at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) {
// A matrix configuration
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
@@ -225,9 +221,12 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
validate(a, b, a_scale, b_scale);
auto M = a.size(0);
auto K = a.size(1);
auto N = b.size(1);

auto out =
at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
at::empty({M, N}, a.options().dtype(at::kBFloat16));
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
using ElementD = cutlass::bfloat16_t;
@@ -236,16 +235,51 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
using ClusterShape = Shape<_2,_1,_1>;
using PerSmTileShape_MNK = Shape<_128,_128,_128>;

run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
return at::Tensor{};
#endif
}

at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");

auto M = a.size(0);
auto K = a.size(1) * 2;
auto N = b.size(1);

auto out =
at::empty({M, N}, a.options().dtype(at::kBFloat16));
using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
using ElementD = cutlass::bfloat16_t;

using MmaTileShape = Shape<_128,_128,_128>;
using ClusterShape = Shape<_2,_1,_1>;
using PerSmTileShape_MNK = Shape<_128,_128,_128>;

run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
}



} // namespace torchao
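
A note on the shape bookkeeping in the fp4 entry point above: the inputs arrive bit-packed, with two e2m1 values per uint8 byte, so the logical contraction dimension is twice the packed storage dimension, which is why K is computed as a.size(1) * 2. A rough illustration of the shapes the wrapper expects, written as a hypothetical Python-side helper (illustrative only, not part of the PR):

import torch

def mx_fp4_gemm_shapes(a_packed: torch.Tensor, b_packed_t: torch.Tensor):
    """Illustrative only: mirrors the dimension logic of mx_fp4_bf16 above.

    a_packed:   (M, K // 2) uint8 -- two e2m1 values per byte
    b_packed_t: (K // 2, N) uint8 -- packed B after .transpose(-1, -2)
    """
    M = a_packed.size(0)
    K = a_packed.size(1) * 2   # logical K is twice the packed storage width
    N = b_packed_t.size(1)
    return M, K, N             # the kernel allocates the output as an (M, N) bf16 tensor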
32 changes: 32 additions & 0 deletions torchao/ops.py
@@ -26,6 +26,7 @@
"rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
)
lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")


def register_custom_op(name):
@@ -640,3 +641,34 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
"""Meta impl for mx_fp8_bf16"""
return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)


def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
"""Defines a matmul between two fp4 tensors w/ MX scales in E8MO and returns a bf16 tensor.
Review comment (Contributor): nit: clarify that it's fp4_e2m1? Maybe link to https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf section 5.3.3?

The expected format is fp4_e2m1 as specified in:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final.pdf (Section 5.3.3)

Note: The MX scales are E8M0 values stored in uint8 tensors (for now).
The layout of the scales is very particular, see:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout


Args:
A: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
B: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout

Returns:
MxN bf16 Tensor

"""
return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale)


@register_custom_op("torchao::mx_fp4_bf16")
def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
"""Meta impl for mx_fp4_bf16"""
# Assume that the contraction happens in the K dim, thus M, N are preserved post bit-pack
return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
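
The meta implementation only propagates output shape and dtype, which is what fake-tensor tracing and torch.compile need. A shape-only sanity sketch follows, under the assumption that the fake registration above is also exercised for plain meta-device tensors; the scale arguments are placeholders here since the meta impl never reads them.

import torch

from torchao import ops  # noqa: F401 -- registers the torchao custom ops

M, K, N = 128, 256, 384
A = torch.empty((M, K // 2), dtype=torch.uint8, device="meta")         # packed fp4 A
B = torch.empty((K // 2, N), dtype=torch.uint8, device="meta")         # packed fp4 B, already transposed
A_scale = torch.empty((M, K // 32), dtype=torch.uint8, device="meta")  # placeholder e8m0 scales
B_scale = torch.empty((N, K // 32), dtype=torch.uint8, device="meta")

out = torch.ops.torchao.mx_fp4_bf16(A, B, A_scale, B_scale)
assert out.shape == (M, N) and out.dtype == torch.bfloat16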