From dc5b10fc4ebda0827ac504f37448a0429bb3343e Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 21 Jun 2024 12:46:41 -0700 Subject: [PATCH 01/19] add unpack cuda --- torchao/csrc/cuda/unpack_int4/unpack_int4.cu | 133 +++++++++++++++++++ torchao/csrc/unpack_int4.cpp | 8 ++ torchao/ops.py | 48 +++++++ 3 files changed, 189 insertions(+) create mode 100644 torchao/csrc/cuda/unpack_int4/unpack_int4.cu create mode 100644 torchao/csrc/unpack_int4.cpp diff --git a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu new file mode 100644 index 0000000000..46233e1070 --- /dev/null +++ b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + const uint64_t blocks = a / b + (a % b != 0); + return blocks; +} +constexpr int32_t kWarpSize = 32; + +template +__global__ void unpack_m16n8k16_Bint4_layout( + // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] + const at::PackedTensorAccessor32 in, + // size [n][k] + at::PackedTensorAccessor32 out) { + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // n dimension that this lane loads from + auto n0 = nTile * kNTileSize + (t / 4); + + // 8 k-tile values, 4 per m16n8k16 mma.sync operand B + int32_t ks[8]; + + // int32_t v[8]; + int32_t v[8]; + + // Store address base offset + auto pOut = &out[n0][0]; + +// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2 +#pragma unroll + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + + // Offsets of innerTile0 + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; + ks[0] = kBase0 + (t % 4) * 2; + ks[1] = ks[0] + 1; + ks[2] = ks[0] + 8; + ks[3] = ks[0] + 8 + 1; + + // Offsets of innerTile1 + auto kBase1 = kBase0 + kKTileSize; + ks[4] = kBase1 + (t % 4) * 2; + ks[5] = ks[4] + 1; + ks[6] = ks[4] + 8; + ks[7] = ks[4] + 8 + 1; + + // inner k-tiles unpack two at a time + int32_t pack = in[nTile][kOuterTile][t][innerKTile / 2]; + v[0] = pack & 0x0000000f; + v[2] = (pack >> 4) & 0x0000000f; + v[4] = (pack >> 8) & 0x0000000f; + v[6] = (pack >> 12) & 0x0000000f; + v[1] = (pack >> 16) & 0x0000000f; + v[3] = (pack >> 20) & 0x0000000f; + v[5] = (pack >> 24) & 0x0000000f; + v[7] = (pack >> 28) & 0x0000000f; + + // Write out +#pragma unroll + for (int i = 0; i < 8; ++i) { + pOut[ks[i]] = v[i]; + } + } +} + +// output is [n][k] (int32 dtype) +// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] +at::Tensor unpack_int4_packed( + const at::Tensor& packed_w, + int64_t innerKTiles) +{ + + c10::cuda::CUDAGuard g(packed_w.device()); + + TORCH_CHECK(packed_w.dim() == 4); + TORCH_CHECK(packed_w.dtype() == at::kInt); + TORCH_CHECK(packed_w.is_contiguous()); + + TORCH_CHECK(packed_w.size(2) == 32); + TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); + TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); + + int N = packed_w.size(0) * 8; + int K = packed_w.size(1) * innerKTiles * 16; + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto nTiles = divUp(N, kNTileSize); + + auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); + + auto out = at::empty( + {N, K}, + at::TensorOptions().dtype(at::kInt).device(packed_w.device())); + + auto stream 
= at::cuda::getCurrentCUDAStream();
  dim3 grid(kSuperTiles, nTiles);

  if (innerKTiles == 2) {
    unpack_m16n8k16_Bint4_layout<2><<>>(
        packed_w.packed_accessor32(),
        out.packed_accessor32());
  }
  else if (innerKTiles == 4) {
    unpack_m16n8k16_Bint4_layout<4><<>>(
        packed_w.packed_accessor32(),
        out.packed_accessor32());
  } else if (innerKTiles == 8) {
    unpack_m16n8k16_Bint4_layout<8><<>>(
        packed_w.packed_accessor32(),
        out.packed_accessor32());
  }

  return out;
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("torchao::unpack_int4_packed", &unpack_int4_packed);
}
diff --git a/torchao/csrc/unpack_int4.cpp b/torchao/csrc/unpack_int4.cpp
new file mode 100644
index 0000000000..ccbf08dbcf
--- /dev/null
+++ b/torchao/csrc/unpack_int4.cpp
@@ -0,0 +1,8 @@
+#include
+#include
+#include
+
+TORCH_LIBRARY_FRAGMENT(torchao, m) {
+  m.impl_abstract_pystub("torchao.ops");
+  m.def("unpack_int4_packed(Tensor packed_w, int innerKTiles) -> Tensor");
+}
diff --git a/torchao/ops.py b/torchao/ops.py
index 25cbfb5656..49e5a7f168 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -43,3 +43,51 @@ def _(_in_feats, _weights, _scales, splitK = 1):
     torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched")

     return _in_feats.new_empty((BS, OC))
+
+
+
+def unpack_int4_packed(packed_w: Tensor, innerKTiles: int) -> Tensor:
+    """
+    Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to the original tensor of shape `N x K`.
+
+    Assumes that the packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `innerKTiles = 2 | 4 | 8`
+
+    Args:
+        packed_w: torch.tensor: 4D tensor with shape (N / 8) x (K / (innerKTiles * 16)) x 32 x innerKTiles / 2, dtype is torch.int32
+        innerKTiles: int
+
+    Returns:
+        torch.tensor of shape N x K, dtype is torch.int32
+
+    """
+    # return torch.ops.torchao.unpack_int4_packed.default(
+    #     packed_w=packed_w, innerKTiles=innerKTiles
+    # )
+    return torch.ops.ao_unpack.unpack_int4_packed.default(
+        packed_w=packed_w, innerKTiles=innerKTiles
+    )
+
+
+@register_custom_op(f"torchao::unpack_int4_packed")
+def _(packed_w: Tensor, innerKTiles: int) -> Tensor:
+    torch._check(
+        packed_w.dim() == 4,
+        lambda: f"packed weight should be a 4D tensor, got {packed_w.dim()}D",
+    )
+    torch._check(
+        packed_w.dtype is torch.int32,
+        lambda: f"weight must be INT32, got {packed_w.dtype}",
+    )
+    torch._check(
+        innerKTiles == 2 or innerKTiles == 4 or innerKTiles == 8,
+        lambda: "innerKTiles must be 2, 4, or 8",
+    )
+    torch._check(packed_w.size(2) == 32, lambda: "packed weight must have 32 at dim 2")
+    torch._check(
+        packed_w.size(3) == innerKTiles / 2,
+        lambda: "packed weight must have innerKTiles/2 at dim 3",
+    )
+    N = packed_w.size(0) * 8
+    K = packed_w.size(1) * innerKTiles * 16
+
+    return torch.empty((N, K), dtype=torch.int32, device=packed_w.device)

From fff3e8a151784cb7f53282efa993d1b4efea67da Mon Sep 17 00:00:00 2001
From: jeromeku
Date: Fri, 21 Jun 2024 12:59:46 -0700
Subject: [PATCH 02/19] add tests

---
 test/test_ops.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++
 torchao/ops.py   |  7 ++-----
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 920b32c5f2..1c958e9fe0 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1,3 +1,4 @@
+import itertools
 import torch
 from torch.testing._internal.common_utils import TestCase, IS_FBCODE
 from torch.testing._internal.optests import opcheck
@@ -55,6 +56,56 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
relative_error = error / results_fp16.abs() assert relative_error.mean() < 1e-2 +## Tests for `unpack_int4_packed` +kTileSizeN = 8 +kTileSizeK = 16 + +SHAPES = [ + (4096, 4096), + # Llama 2 GEMM shapes + (4096, 11008), + (11008, 4096), + # Llama 3 GEMM shapes + (4096, 14336), + (14336, 4096), +] +INNERKTILES = [2, 4, 8] + +TEST_CONFIGS = list(itertools.product(SHAPES, INNERKTILES)) + +@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") +@pytest.mark.skipif(not torch.cuda.is_available(), "CUDA not available") +@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS, ids=str) +def test_int4_unpack_correctness(shape, innerKTiles): + N, K = shape + assert K % (innerKTiles * kTileSizeK) == 0 and N % kTileSizeN == 0 + + t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) + unpacked = torchao.ops.unpack_int4_packed(packed_w, innerKTiles) + assert torch.allclose(t, unpacked) + + +# TODO: Fix "test_aot_dispatch_dynamic" test failure +@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS , ids=str) +def test_int4_unpack_op(shape, innerKTiles): + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + # "test_aot_dispatch_dynamic", + ] + t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) + + opcheck( + torch.ops.torchao.unpack_int4_packed, + (packed_w, innerKTiles), + test_utils=test_utils, + ) + if __name__ == "__main__": unittest.main() diff --git a/torchao/ops.py b/torchao/ops.py index 49e5a7f168..6fa3d811b3 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -60,13 +60,10 @@ def unpack_int4_packed(packed_w: Tensor, innerKTiles: int) -> Tensor: torch.tensor of shape is N x K, dtype is torch.int32 """ - # return torch.ops.torchao.unpack_int4_packed.default( - # packed_w=packed_w, innerKTiles=innerKTiles - # ) - return torch.ops.ao_unpack.unpack_int4_packed.default( + return torch.ops.torchao.unpack_int4_packed.default( packed_w=packed_w, innerKTiles=innerKTiles ) - + @register_custom_op(f"torchao::unpack_int4_packed") def _(packed_w: Tensor, innerKTiles: int) -> Tensor: From 39f23cf6bcdc8cac3e19a261ac2b0b9e620cf0f7 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 21 Jun 2024 14:13:18 -0700 Subject: [PATCH 03/19] fix tests --- test/test_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1c958e9fe0..b318ac699b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -73,8 +73,8 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK): TEST_CONFIGS = list(itertools.product(SHAPES, INNERKTILES)) -@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") -@pytest.mark.skipif(not torch.cuda.is_available(), "CUDA not available") +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS, ids=str) def test_int4_unpack_correctness(shape, innerKTiles): N, K = shape @@ -87,7 +87,7 @@ def test_int4_unpack_correctness(shape, innerKTiles): # TODO: Fix 
"test_aot_dispatch_dynamic" test failure -@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS , ids=str) def test_int4_unpack_op(shape, innerKTiles): From e41b68292d4613bd97f68f9e67dd76013e59d235 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 23 Jun 2024 14:39:10 -0700 Subject: [PATCH 04/19] refactor tinygemm unpacking kernel --- test/test_ops.py | 4 +- torchao/csrc/cuda/unpack_int4/unpack_int4.cu | 108 ++++++++++++------- torchao/csrc/unpack_int4.cpp | 2 +- torchao/ops.py | 6 +- 4 files changed, 77 insertions(+), 43 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index b318ac699b..a3f2cba07d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -82,7 +82,7 @@ def test_int4_unpack_correctness(shape, innerKTiles): t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) - unpacked = torchao.ops.unpack_int4_packed(packed_w, innerKTiles) + unpacked = torchao.ops.unpack_int4_to_int(packed_w, innerKTiles) assert torch.allclose(t, unpacked) @@ -101,7 +101,7 @@ def test_int4_unpack_op(shape, innerKTiles): packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) opcheck( - torch.ops.torchao.unpack_int4_packed, + torch.ops.torchao.unpack_int4_to_int, (packed_w, innerKTiles), test_utils=test_utils, ) diff --git a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu index 46233e1070..22cec81f76 100644 --- a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu +++ b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu @@ -13,12 +13,15 @@ constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { } constexpr int32_t kWarpSize = 32; -template -__global__ void unpack_m16n8k16_Bint4_layout( - // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] +// in size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] +// +// out size [n][k] +template +__global__ void _dequantize_int4_kernel( const at::PackedTensorAccessor32 in, - // size [n][k] - at::PackedTensorAccessor32 out) { + at::PackedTensorAccessor32 out, + at::optional> scales_and_zeros = c10::nullopt) +{ constexpr int32_t kNTileSize = 8; constexpr int32_t kKTileSize = 16; @@ -31,54 +34,85 @@ __global__ void unpack_m16n8k16_Bint4_layout( auto n0 = nTile * kNTileSize + (t / 4); // 8 k-tile values, 4 per m16n8k16 mma.sync operand B - int32_t ks[8]; - - // int32_t v[8]; - int32_t v[8]; + // int32_t ks[8]; + //Only need 4 offsets since TC layout for single tile is 2x2 (2 pairs of 2 contiguous values) + int32_t ks[4]; // Store address base offset auto pOut = &out[n0][0]; - + // Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2 #pragma unroll for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { - + //Tensor-core layout for m16n8k16 is such that each tile has 2 pairs of 2 contiguous values + //Hence, we only need 4 offsets // Offsets of innerTile0 auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; ks[0] = kBase0 + (t % 4) * 2; - ks[1] = ks[0] + 1; - ks[2] = ks[0] + 8; - ks[3] = ks[0] + 8 + 1; + ks[1] = ks[0] + 8; // Offsets of innerTile1 auto kBase1 = kBase0 + kKTileSize; - ks[4] = kBase1 + (t % 4) * 2; - ks[5] = ks[4] + 1; - ks[6] = 
ks[4] + 8; - ks[7] = ks[4] + 8 + 1; + ks[2] = kBase1 + (t % 4) * 2; + ks[3] = ks[2] + 8; // inner k-tiles unpack two at a time int32_t pack = in[nTile][kOuterTile][t][innerKTile / 2]; - v[0] = pack & 0x0000000f; - v[2] = (pack >> 4) & 0x0000000f; - v[4] = (pack >> 8) & 0x0000000f; - v[6] = (pack >> 12) & 0x0000000f; - v[1] = (pack >> 16) & 0x0000000f; - v[3] = (pack >> 20) & 0x0000000f; - v[5] = (pack >> 24) & 0x0000000f; - v[7] = (pack >> 28) & 0x0000000f; - - // Write out -#pragma unroll - for (int i = 0; i < 8; ++i) { - pOut[ks[i]] = v[i]; - } + + if constexpr(kDequant) { + // static_assert(scales_and_zeros.has_value(), "scales_and_zeros must be set when dequantizing"); + static_assert(std::is_same::value, "Out must be BFloat16 when dequantizing"); + __nv_bfloat16 v[8]; + + v[0] = __int2bfloat16_rn(pack & 0x0000000f); + v[2] = __int2bfloat16_rn((pack >> 4) & 0x0000000f); + v[4] = __int2bfloat16_rn((pack >> 8) & 0x0000000f); + v[6] = __int2bfloat16_rn((pack >> 12) & 0x0000000f); + v[1] = __int2bfloat16_rn((pack >> 16) & 0x0000000f); + v[3] = __int2bfloat16_rn((pack >> 20) & 0x0000000f); + v[5] = __int2bfloat16_rn((pack >> 24) & 0x0000000f); + v[7] = __int2bfloat16_rn((pack >> 28) & 0x0000000f); + + // All b values within a 16x16 tile should fall within the same q group + // Hence we load 1 scale and zero per loop + int qgroup = ks[0] / groupSize; + const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + + //Reinterpret as pairs of v as pairs of bfloat16 + __nv_bfloat162 *v_bf16x2 = reinterpret_cast<__nv_bfloat162*>(v); + __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); + __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); + + #pragma unroll + for (int i = 0; i < 4; i++) { + reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hmul2(scale2, __hsub2(v_bf16x2[i], zero2));; + } + } + else { + static_assert(std::is_same::value, "Out must be int32_t when unpacking to int"); + int32_t v[8]; + + v[0] = pack & 0x0000000f; + v[2] = (pack >> 4) & 0x0000000f; + v[4] = (pack >> 8) & 0x0000000f; + v[6] = (pack >> 12) & 0x0000000f; + v[1] = (pack >> 16) & 0x0000000f; + v[3] = (pack >> 20) & 0x0000000f; + v[5] = (pack >> 24) & 0x0000000f; + v[7] = (pack >> 28) & 0x0000000f; + int2* v_i32x2 = reinterpret_cast(v); + + #pragma unroll + for (int i = 0; i < 4; ++i) { + reinterpret_cast(&pOut[ks[i]])[0] = v_i32x2[i]; + } + } } } // output is [n][k] (int32 dtype) // input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] -at::Tensor unpack_int4_packed( +at::Tensor _unpack_int4_to_int( const at::Tensor& packed_w, int64_t innerKTiles) { @@ -111,16 +145,16 @@ at::Tensor unpack_int4_packed( dim3 grid(kSuperTiles, nTiles); if (innerKTiles == 2) { - unpack_m16n8k16_Bint4_layout<2><<>>( + _dequantize_int4_kernel<<>>( packed_w.packed_accessor32(), out.packed_accessor32()); } else if (innerKTiles == 4) { - unpack_m16n8k16_Bint4_layout<4><<>>( + _dequantize_int4_kernel<<>>( packed_w.packed_accessor32(), out.packed_accessor32()); } else if (innerKTiles == 8) { - unpack_m16n8k16_Bint4_layout<8><<>>( + _dequantize_int4_kernel<<>>( packed_w.packed_accessor32(), out.packed_accessor32()); } @@ -129,5 +163,5 @@ at::Tensor unpack_int4_packed( } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::unpack_int4_packed", &unpack_int4_packed); + m.impl("torchao::unpack_int4_to_int", &_unpack_int4_to_int); } diff --git a/torchao/csrc/unpack_int4.cpp b/torchao/csrc/unpack_int4.cpp index ccbf08dbcf..1c427dc51a 100644 --- a/torchao/csrc/unpack_int4.cpp +++ 
b/torchao/csrc/unpack_int4.cpp @@ -4,5 +4,5 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); - m.def("unpack_int4_packed(Tensor packed_w, int innerKTiles) -> Tensor"); + m.def("unpack_int4_to_int(Tensor packed_w, int innerKTiles) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index 6fa3d811b3..56b36f4b6c 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -46,7 +46,7 @@ def _(_in_feats, _weights, _scales, splitK = 1): -def unpack_int4_packed(packed_w: Tensor, innerKTiles: int) -> Tensor: +def unpack_int4_to_int(packed_w: Tensor, innerKTiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. @@ -60,12 +60,12 @@ def unpack_int4_packed(packed_w: Tensor, innerKTiles: int) -> Tensor: torch.tensor of shape is N x K, dtype is torch.int32 """ - return torch.ops.torchao.unpack_int4_packed.default( + return torch.ops.torchao.unpack_int4_to_int.default( packed_w=packed_w, innerKTiles=innerKTiles ) -@register_custom_op(f"torchao::unpack_int4_packed") +@register_custom_op(f"torchao::unpack_int4_to_int") def _(packed_w: Tensor, innerKTiles: int) -> Tensor: torch._check( packed_w.dim() == 4, From 3a3d788371824f606676046a9601455661ec0593 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 23 Jun 2024 15:29:57 -0700 Subject: [PATCH 05/19] add dequant --- test/test_ops.py | 85 +++++++++++++++- torchao/csrc/cuda/unpack_int4/unpack_int4.cu | 101 +++++++++++++++++++ torchao/csrc/unpack_int4.cpp | 2 + torchao/ops.py | 60 +++++++++++ 4 files changed, 243 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a3f2cba07d..bdd6b50738 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -8,6 +8,8 @@ from parameterized import parameterized import pytest +import torchao.quantization + try: import torchao.ops except RuntimeError: @@ -70,12 +72,13 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK): (14336, 4096), ] INNERKTILES = [2, 4, 8] - -TEST_CONFIGS = list(itertools.product(SHAPES, INNERKTILES)) +QGROUP_SIZES = [32, 64, 128, 256] +TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES)) +TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS, ids=str) +@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str) def test_int4_unpack_correctness(shape, innerKTiles): N, K = shape assert K % (innerKTiles * kTileSizeK) == 0 and N % kTileSizeN == 0 @@ -85,11 +88,10 @@ def test_int4_unpack_correctness(shape, innerKTiles): unpacked = torchao.ops.unpack_int4_to_int(packed_w, innerKTiles) assert torch.allclose(t, unpacked) - # TODO: Fix "test_aot_dispatch_dynamic" test failure @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS , ids=str) +@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK , ids=str) def test_int4_unpack_op(shape, innerKTiles): test_utils = [ "test_schema", @@ -106,6 +108,79 @@ def test_int4_unpack_op(shape, innerKTiles): test_utils=test_utils, ) +def dequant_ref(q, scales, zeros, group_size, dtype=torch.bfloat16): + 
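+    # Reference groupwise dequantization used by the tests below: upcast the
+    # unpacked int4 values to bfloat16, then apply the per-group affine
+    # transform (q - zero) * scale over contiguous chunks of `group_size`
+    # elements along K.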
n, k = q.shape + assert q.dtype == torch.int + + n_groups = k // group_size + assert scales.shape[0] == n and scales.shape[1] == n_groups + assert scales.shape == zeros.shape + + q_bf16 = q.to(dtype=dtype) + q_bf16 = q_bf16.reshape(-1, group_size) + dq = (q_bf16 - zeros.reshape(-1, 1)) * scales.reshape(-1, 1) + return dq.reshape(n, k) + +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_int4_correctness(shape, innerKTiles, group_size): + n, k = shape + + # tinygemm params + nTileSize = 8 + kTileSize = 16 + nTiles = n // nTileSize + kTiles = k // (innerKTiles * kTileSize) + numThreads = 32 + + device = "cuda" + q = torch.randint(0, 16, shape, dtype=torch.int, device=device) + packed_w = torch._convert_weight_to_int4pack(q, innerKTiles) + # tinygemm params + assert packed_w.shape == torch.Size([nTiles, kTiles, numThreads, innerKTiles // 2]) + + # scales and zeros init + q_groups = k // group_size + scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) + zeros = torch.randn_like(scales) + + scales_and_zeros = torchao.quantization.utils.pack_tinygemm_scales_and_zeros(scales, zeros) + assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) + scales_unpacked, zeros_unpacked = torchao.quantization.utils.unpack_tinygemm_scales_and_zeros(scales_and_zeros) + assert torch.allclose(scales_unpacked.reshape(scales.shape), scales) + assert torch.allclose(zeros_unpacked.reshape(zeros.shape), zeros) + + dq_ref = dequant_ref(q, scales, zeros, group_size) + dq = torchao.ops.dequantize_int4(packed_w, scales_and_zeros, group_size, innerKTiles) + assert torch.allclose(dq, dq_ref, atol=1e-4, rtol=1e-4) + +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_int4_op(shape, innerKTiles, group_size): + n, k = shape + + device = "cuda" + q = torch.randint(0, 16, shape, dtype=torch.int, device=device) + packed_w = torch._convert_weight_to_int4pack(q, innerKTiles) + print(packed_w.shape) + q_groups = k // group_size + scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) + zeros = torch.randn_like(scales) + scales_and_zeros = torchao.quantization.utils.pack_tinygemm_scales_and_zeros(scales, zeros) + + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + # "test_aot_dispatch_dynamic", + ] + opcheck( + torch.ops.torchao.dequantize_int4, + (packed_w, scales_and_zeros, group_size, innerKTiles), + test_utils=test_utils, + ) if __name__ == "__main__": unittest.main() diff --git a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu index 22cec81f76..301a08396f 100644 --- a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu +++ b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu @@ -110,6 +110,105 @@ __global__ void _dequantize_int4_kernel( } } +// output is [n][k] (int32 dtype) +// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] +// scales_and_zeros is [numQGroups][n][2] +// qGroupSize is 32, 64, 128 or 256 +at::Tensor _dequantize_int4( + const at::Tensor& packed_w, + const at::Tensor& scales_and_zeros, + 
int64_t group_size, + int64_t innerKTiles) +{ + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + c10::cuda::CUDAGuard g(packed_w.device()); + + // packed_w preconditions + TORCH_CHECK(packed_w.dim() == 4); + TORCH_CHECK(packed_w.dtype() == at::kInt); + TORCH_CHECK(packed_w.is_contiguous()); + TORCH_CHECK(packed_w.size(2) == 32); + TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); + TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); + + auto numQGroups = scales_and_zeros.size(0); + int N = packed_w.size(0) * kNTileSize; + int K = packed_w.size(1) * innerKTiles * kKTileSize; + + // scales_and_zeros preconditions + TORCH_CHECK( + group_size == 32 || group_size == 64 || group_size == 128 || + group_size == 256); + TORCH_CHECK(numQGroups == K / group_size); + TORCH_CHECK(scales_and_zeros.dim() == 3); + TORCH_CHECK(scales_and_zeros.size(1) == N); + TORCH_CHECK(scales_and_zeros.size(2) == 2); + + auto nTiles = divUp(N, kNTileSize); + auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); + auto out = at::empty( + {N, K}, + at::TensorOptions().dtype(at::kBFloat16).device(packed_w.device())); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 grid(kSuperTiles, nTiles); + +#define RUN_DEQUANT(QGROUPSIZE) \ + do { \ + switch(innerKTiles) { \ + case 2: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + case 4: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + case 8: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + default: \ + break; \ + } \ + } while(false) + +#define DISPATCH_Q_GROUP() \ + do { \ + switch (group_size) { \ + case 32: \ + RUN_DEQUANT(32); \ + break; \ + case 64: \ + RUN_DEQUANT(64); \ + break; \ + case 128: \ + RUN_DEQUANT(128); \ + break; \ + case 256: \ + RUN_DEQUANT(256); \ + break; \ + default: \ + break; \ + } \ + } while(false) + + DISPATCH_Q_GROUP(); + #undef DISPATCH_Q_GROUP + #undef RUN_DEQUANT + + return out; +} + // output is [n][k] (int32 dtype) // input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] at::Tensor _unpack_int4_to_int( @@ -164,4 +263,6 @@ at::Tensor _unpack_int4_to_int( TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::unpack_int4_to_int", &_unpack_int4_to_int); + m.impl("torchao::dequantize_int4", &_dequantize_int4); + } diff --git a/torchao/csrc/unpack_int4.cpp b/torchao/csrc/unpack_int4.cpp index 1c427dc51a..87a8c176a4 100644 --- a/torchao/csrc/unpack_int4.cpp +++ b/torchao/csrc/unpack_int4.cpp @@ -5,4 +5,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("unpack_int4_to_int(Tensor packed_w, int innerKTiles) -> Tensor"); + m.def("dequantize_int4(Tensor packed_w, Tensor scales_and_zeros, int group_size, int innerKTiles) -> Tensor"); + } diff --git a/torchao/ops.py b/torchao/ops.py index 56b36f4b6c..793cb93a97 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -88,3 +88,63 @@ def _(packed_w: Tensor, innerKTiles: int) -> Tensor: K = packed_w.size(1) * innerKTiles * 16 return torch.empty((N, K), dtype=torch.int32, device=packed_w.device) + +def dequantize_int4(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: int) -> Tensor: + """ + Dequantizes by: + - Unpacking weights that were packed with 
`torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K` + - Upcasting to bfloat16 + - Dequantizing with the scales_and_zeros that were packed with `torchao.quantization.utils.pack_tinygemm_scales_and_zeros` + + Assumes: + - packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `innerKTiles = 2 | 4 | 8`" + - packed scales_and_zeros were generated with `torchao.quantization.utils.pack_tinygemm_scales_and_zeros` + - qGroupSize is 32 | 64 | 128 | 256 + + Args: + packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (innerKTiles * 16)) x 32 x innerKTiles / 2`, dtype is torch.int32 + scales_and_zeros: torch.tensor: 3D tensor with shape `numQGroups x N x 2`, dtype is torch.bfloat16 where numQGroups is K / qGroupSize + qGroupSize: int + innerKTiles: int + + Returns: + torch.tensor of shape is N x K, dtype is torch.bfloat16 + + """ + return torch.ops.torchao.dequantize_int4.default( + packed_w, scales_and_zeros, group_size, innerKTiles + ) + + +@register_custom_op(f"torchao::dequantize_int4") +def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: int) -> Tensor: + # packed_w preconditions + torch._check( + packed_w.dim() == 4, + lambda: f"packed weight should be a 4d tensor, got {packed_w.dim()}D", + ) + torch._check( + packed_w.dtype is torch.int32, + lambda: f"weight must be INT32, got {packed_w.dtype}", + ) + torch._check( + innerKTiles == 2 or innerKTiles == 4 or innerKTiles == 8, + lambda: "innerKTiles must be 2, 4, or 8", + ) + torch._check(packed_w.size(2) == 32, lambda: "packed weight must have 32 at dim 2") + torch._check( + packed_w.size(3) == innerKTiles / 2, + lambda: "packed weight must have innerKTiles/2 at dim 3", + ) + N = packed_w.size(0) * 8 + K = packed_w.size(1) * innerKTiles * 16 + + # scales_and_zeros preconditions + torch._check(scales_and_zeros.dtype is torch.bfloat16, lambda: "scales_and_zeros must be bfloat16") + torch._check(scales_and_zeros.dim() == 3, lambda: "scales_and_zeros must be 3D, got {scales_and_zeros.dim()}") + torch._check(group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, lambda: "qGroupSize must be 32, 64, 128, or 256") + torch._check(scales_and_zeros.size(0) == K // group_size, lambda: "scales_and_zeros must have K // qGroupSize at dim 0") + torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1") + torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") + + return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) From a2ca1498a9e6d852b5e31ac46d9f95f2f8db07c8 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 23 Jun 2024 15:36:26 -0700 Subject: [PATCH 06/19] add additional dequant check --- test/test_ops.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index bdd6b50738..fe567711ae 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -154,7 +154,19 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): dq_ref = dequant_ref(q, scales, zeros, group_size) dq = torchao.ops.dequantize_int4(packed_w, scales_and_zeros, group_size, innerKTiles) assert torch.allclose(dq, dq_ref, atol=1e-4, rtol=1e-4) - + + # TODO: Figure out why this fails + # This is how torchao.dtypes.affine_quantized_tensor recovers the original tensor + # https://github.com/pytorch/ao/blob/9dc2c118f59ad4135a8c39166c4ceebda73c62a9/torchao/dtypes/affine_quantized_tensor.py#L505 + # a_eye = torch.eye(k, 
device=device, dtype=torch.bfloat16) + # dq_check = torch.ops.aten._weight_int4pack_mm( + # a_eye, + # packed_w, + # group_size, + # scales_and_zeros, + # ).t() + # assert torch.allclose(dq, dq_check, atol=1e-4, rtol=1e-4) + @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) From 052d4825dd52fcd02d4fd87ccdbca58d800bd14f Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 27 Jun 2024 00:34:45 +0000 Subject: [PATCH 07/19] update tinygemm dequantize test --- test/test_ops.py | 71 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index fe567711ae..46c0af1708 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,7 +7,13 @@ import unittest from parameterized import parameterized import pytest - +from torchao.quantization.utils import ( + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor_from_qparams, + groupwise_affine_dequantize_tensor_from_qparams, + pack_tinygemm_scales_and_zeros, + unpack_tinygemm_scales_and_zeros +) import torchao.quantization try: @@ -108,7 +114,7 @@ def test_int4_unpack_op(shape, innerKTiles): test_utils=test_utils, ) -def dequant_ref(q, scales, zeros, group_size, dtype=torch.bfloat16): +def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): n, k = q.shape assert q.dtype == torch.int @@ -116,16 +122,24 @@ def dequant_ref(q, scales, zeros, group_size, dtype=torch.bfloat16): assert scales.shape[0] == n and scales.shape[1] == n_groups assert scales.shape == zeros.shape - q_bf16 = q.to(dtype=dtype) - q_bf16 = q_bf16.reshape(-1, group_size) - dq = (q_bf16 - zeros.reshape(-1, 1)) * scales.reshape(-1, 1) + midpoint = 2 ** (nbits - 1) + + #Convert fron u4 -> s4 and upcast to bfloat16 + q = q.sub(midpoint).to(dtype) + + # Dequantize + q = q.reshape(-1, group_size) + dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1) + return dq.reshape(n, k) + @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT[:1], ids=str) def test_dequantize_int4_correctness(shape, innerKTiles, group_size): n, k = shape + dtype = torch.bfloat16 # tinygemm params nTileSize = 8 @@ -135,25 +149,40 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): numThreads = 32 device = "cuda" - q = torch.randint(0, 16, shape, dtype=torch.int, device=device) - packed_w = torch._convert_weight_to_int4pack(q, innerKTiles) - # tinygemm params - assert packed_w.shape == torch.Size([nTiles, kTiles, numThreads, innerKTiles // 2]) - # scales and zeros init - q_groups = k // group_size - scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) - zeros = torch.randn_like(scales) + t = torch.randn(n, k, dtype=dtype, device=device) + scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) - scales_and_zeros = torchao.quantization.utils.pack_tinygemm_scales_and_zeros(scales, zeros) + # Quantize + q = groupwise_affine_quantize_tensor_from_qparams( + t, scales, zeros, n_bit=4, groupsize=group_size 
+ ) + + # Pack to tensor core layout + packed = torch.ops.aten._convert_weight_to_int4pack(q, innerKTiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + q_groups = k // group_size assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) - scales_unpacked, zeros_unpacked = torchao.quantization.utils.unpack_tinygemm_scales_and_zeros(scales_and_zeros) - assert torch.allclose(scales_unpacked.reshape(scales.shape), scales) - assert torch.allclose(zeros_unpacked.reshape(zeros.shape), zeros) + + dq_ao = groupwise_affine_dequantize_tensor_from_qparams( + q, scales, zeros, n_bit=4, groupsize=group_size + ) dq_ref = dequant_ref(q, scales, zeros, group_size) - dq = torchao.ops.dequantize_int4(packed_w, scales_and_zeros, group_size, innerKTiles) - assert torch.allclose(dq, dq_ref, atol=1e-4, rtol=1e-4) + + print((dq_ao - dq_ref).abs().max()) + + # test dequant using identity mat + a_eye = torch.eye(k, device=device, dtype=dtype) + dq_check = torch.ops.aten._weight_int4pack_mm( + a_eye, + packed, + group_size, + scales_and_zeros, + ).t() + print((dq_check - dq_ref).abs().max()) + print((dq_check - dq_ao).abs().max()) + # TODO: Figure out why this fails # This is how torchao.dtypes.affine_quantized_tensor recovers the original tensor @@ -172,8 +201,8 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): @pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) def test_dequantize_int4_op(shape, innerKTiles, group_size): n, k = shape - device = "cuda" + q = torch.randint(0, 16, shape, dtype=torch.int, device=device) packed_w = torch._convert_weight_to_int4pack(q, innerKTiles) print(packed_w.shape) From 18c505fee01dfc1c770c86faf5bf4f47b0ebb28d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 27 Jun 2024 01:03:54 +0000 Subject: [PATCH 08/19] correct dequant kernel logic --- test/test_ops.py | 13 ++-- torchao/csrc/cuda/unpack_int4/unpack_int4.cu | 74 ++++++++++++++++---- 2 files changed, 68 insertions(+), 19 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 46c0af1708..e73b95c6d2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -169,21 +169,22 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): ) dq_ref = dequant_ref(q, scales, zeros, group_size) - - print((dq_ao - dq_ref).abs().max()) + print() + print(f"dq_ao - dq_ref: {(dq_ao - dq_ref).abs().max()}") # test dequant using identity mat a_eye = torch.eye(k, device=device, dtype=dtype) - dq_check = torch.ops.aten._weight_int4pack_mm( + dq_id = torch.ops.aten._weight_int4pack_mm( a_eye, packed, group_size, scales_and_zeros, ).t() - print((dq_check - dq_ref).abs().max()) - print((dq_check - dq_ao).abs().max()) - + print(f"dq_id - dq_ao: {(dq_id - dq_ao).abs().max()}") + dq_test = torchao.ops.dequantize_int4(packed, scales_and_zeros, group_size, innerKTiles) + print(f"dq_test - dq_ao: {(dq_test - dq_ao).abs().max()}") + print(f"dq_test - dq_id: {(dq_test - dq_id).abs().max()}") # TODO: Figure out why this fails # This is how torchao.dtypes.affine_quantized_tensor recovers the original tensor # https://github.com/pytorch/ao/blob/9dc2c118f59ad4135a8c39166c4ceebda73c62a9/torchao/dtypes/affine_quantized_tensor.py#L505 diff --git a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu index 301a08396f..a6604277a4 100644 --- a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu +++ b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu @@ -12,7 +12,54 @@ constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + 
b) { return blocks; } constexpr int32_t kWarpSize = 32; +struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; +}; + +inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#pragma unroll + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + +// Finally, we construct the output numbers. +#pragma unroll + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + + return result; +} // in size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] // // out size [n][k] @@ -62,30 +109,31 @@ __global__ void _dequantize_int4_kernel( if constexpr(kDequant) { // static_assert(scales_and_zeros.has_value(), "scales_and_zeros must be set when dequantizing"); static_assert(std::is_same::value, "Out must be BFloat16 when dequantizing"); - __nv_bfloat16 v[8]; - - v[0] = __int2bfloat16_rn(pack & 0x0000000f); - v[2] = __int2bfloat16_rn((pack >> 4) & 0x0000000f); - v[4] = __int2bfloat16_rn((pack >> 8) & 0x0000000f); - v[6] = __int2bfloat16_rn((pack >> 12) & 0x0000000f); - v[1] = __int2bfloat16_rn((pack >> 16) & 0x0000000f); - v[3] = __int2bfloat16_rn((pack >> 20) & 0x0000000f); - v[5] = __int2bfloat16_rn((pack >> 24) & 0x0000000f); - v[7] = __int2bfloat16_rn((pack >> 28) & 0x0000000f); - + // __nv_bfloat16 v[8]; + + // // Extract u4, convert to s4 by subtracting by 2 ** nbits / 2, then convert to bfloat16 + // v[0] = __int2bfloat16_rn(pack & 0x0000000f - 8); + // v[2] = __int2bfloat16_rn((pack >> 4) & 0x0000000f - 8); + // v[4] = __int2bfloat16_rn((pack >> 8) & 0x0000000f - 8); + // v[6] = __int2bfloat16_rn((pack >> 12) & 0x0000000f - 8); + // v[1] = __int2bfloat16_rn((pack >> 16) & 0x0000000f - 8); + // v[3] = __int2bfloat16_rn((pack >> 20) & 0x0000000f - 8); + // v[5] = __int2bfloat16_rn((pack >> 24) & 0x0000000f - 8); + // v[7] = __int2bfloat16_rn((pack >> 28) & 0x0000000f - 8); + bf16x2x4 v_bf16x2 = convert_i4x8_to_bf16x2x4(pack); // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); //Reinterpret as pairs of v as pairs of bfloat16 - __nv_bfloat162 *v_bf16x2 = reinterpret_cast<__nv_bfloat162*>(v); + // __nv_bfloat162 *v_bf16x2 = reinterpret_cast<__nv_bfloat162*>(v); __nv_bfloat162 scale2 = 
__bfloat162bfloat162(pSZ[0]); __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); #pragma unroll for (int i = 0; i < 4; i++) { - reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hmul2(scale2, __hsub2(v_bf16x2[i], zero2));; + reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2.vals[i], scale2, zero2); } } else { From d831a5eb71b6ebda57498b70ee49b7aa5189b51d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 27 Jun 2024 01:10:12 +0000 Subject: [PATCH 09/19] clean up kernel --- torchao/csrc/cuda/unpack_int4/unpack_int4.cu | 22 ++++++++------------ 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu index a6604277a4..18724c6cc5 100644 --- a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu +++ b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu @@ -12,10 +12,14 @@ constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { return blocks; } constexpr int32_t kWarpSize = 32; + +//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization +//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180 struct __align__(16) bf16x2x4 { __nv_bfloat162 vals[4]; }; +//Copied from https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L195C1-L241C1 inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { bf16x2x4 result; constexpr int kElements = 8; @@ -61,7 +65,7 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { return result; } // in size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] -// +// scales_and_zeros size [numQGroups][n][2] // out size [n][k] template __global__ void _dequantize_int4_kernel( @@ -112,28 +116,20 @@ __global__ void _dequantize_int4_kernel( // __nv_bfloat16 v[8]; // // Extract u4, convert to s4 by subtracting by 2 ** nbits / 2, then convert to bfloat16 - // v[0] = __int2bfloat16_rn(pack & 0x0000000f - 8); - // v[2] = __int2bfloat16_rn((pack >> 4) & 0x0000000f - 8); - // v[4] = __int2bfloat16_rn((pack >> 8) & 0x0000000f - 8); - // v[6] = __int2bfloat16_rn((pack >> 12) & 0x0000000f - 8); - // v[1] = __int2bfloat16_rn((pack >> 16) & 0x0000000f - 8); - // v[3] = __int2bfloat16_rn((pack >> 20) & 0x0000000f - 8); - // v[5] = __int2bfloat16_rn((pack >> 24) & 0x0000000f - 8); - // v[7] = __int2bfloat16_rn((pack >> 28) & 0x0000000f - 8); - bf16x2x4 v_bf16x2 = convert_i4x8_to_bf16x2x4(pack); + bf16x2x4 v_bf16x2x4 = convert_i4x8_to_bf16x2x4(pack); + // All b values within a 16x16 tile should fall within the same q group // Hence we load 1 scale and zero per loop int qgroup = ks[0] / groupSize; const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); - //Reinterpret as pairs of v as pairs of bfloat16 - // __nv_bfloat162 *v_bf16x2 = reinterpret_cast<__nv_bfloat162*>(v); + // Vectorize scales and zeros __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); #pragma unroll for (int i = 0; i < 4; i++) { - reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2.vals[i], scale2, zero2); + reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2); } } else { From 48a80622989846cc56fe18becdc70253b3dd3b09 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 28 Jun 2024 00:32:18 +0000 Subject: [PATCH 10/19] update dequantize kernel 
tests --- test/test_ops.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index e73b95c6d2..789074f919 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -136,7 +136,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT[:1], ids=str) +@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) def test_dequantize_int4_correctness(shape, innerKTiles, group_size): n, k = shape dtype = torch.bfloat16 @@ -167,12 +167,8 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): dq_ao = groupwise_affine_dequantize_tensor_from_qparams( q, scales, zeros, n_bit=4, groupsize=group_size ) - - dq_ref = dequant_ref(q, scales, zeros, group_size) - print() - print(f"dq_ao - dq_ref: {(dq_ao - dq_ref).abs().max()}") - # test dequant using identity mat + # Dequantize by passing in an identity matrix as the activation a_eye = torch.eye(k, device=device, dtype=dtype) dq_id = torch.ops.aten._weight_int4pack_mm( a_eye, @@ -180,22 +176,25 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): group_size, scales_and_zeros, ).t() - print(f"dq_id - dq_ao: {(dq_id - dq_ao).abs().max()}") - dq_test = torchao.ops.dequantize_int4(packed, scales_and_zeros, group_size, innerKTiles) - print(f"dq_test - dq_ao: {(dq_test - dq_ao).abs().max()}") - print(f"dq_test - dq_id: {(dq_test - dq_id).abs().max()}") - # TODO: Figure out why this fails - # This is how torchao.dtypes.affine_quantized_tensor recovers the original tensor - # https://github.com/pytorch/ao/blob/9dc2c118f59ad4135a8c39166c4ceebda73c62a9/torchao/dtypes/affine_quantized_tensor.py#L505 - # a_eye = torch.eye(k, device=device, dtype=torch.bfloat16) - # dq_check = torch.ops.aten._weight_int4pack_mm( - # a_eye, - # packed_w, - # group_size, - # scales_and_zeros, - # ).t() - # assert torch.allclose(dq, dq_check, atol=1e-4, rtol=1e-4) + # Actual operation to test + dq_op = torchao.ops.dequantize_int4(packed, scales_and_zeros, group_size, innerKTiles) + + + # Compare results + diff_ao_id = (dq_id - dq_ao).abs().max() + diff_op_id = (dq_op - dq_id).abs().max() + diff_op_ao = (dq_op - dq_ao).abs().max() + + # There are slight numerical differences when dequantizing with an identity matrix + # Since the `dequantize_int4` kernel relies on same underlying numerical conversions, this gives same + # numerical differences when compared to the `groupwise_affine_dequantize` + + # Test that the `dequant` kernel gives same results as identity matrix-based dequant + assert diff_op_id == 0 + + # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix + assert diff_op_ao == diff_ao_id @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") From 279b79a6bbd209b5a3e9e00162b20955231c2ffa Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 28 Jun 2024 00:55:17 +0000 Subject: [PATCH 11/19] rename kernel ops to tensor_core_tiled_layout --- test/test_ops.py | 25 +- torchao/csrc/cuda/unpack_int4/unpack_int4.cu | 312 
------------------- torchao/csrc/unpack_int4.cpp | 10 - torchao/ops.py | 14 +- 4 files changed, 19 insertions(+), 342 deletions(-) delete mode 100644 torchao/csrc/cuda/unpack_int4/unpack_int4.cu delete mode 100644 torchao/csrc/unpack_int4.cpp diff --git a/test/test_ops.py b/test/test_ops.py index 789074f919..adc4572f09 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -82,23 +82,23 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK): TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES)) TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str) -def test_int4_unpack_correctness(shape, innerKTiles): +def test_unpack_tensor_core_tiled_layout_correctness(shape, innerKTiles): N, K = shape assert K % (innerKTiles * kTileSizeK) == 0 and N % kTileSizeN == 0 t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) - unpacked = torchao.ops.unpack_int4_to_int(packed_w, innerKTiles) + unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, innerKTiles) assert torch.allclose(t, unpacked) # TODO: Fix "test_aot_dispatch_dynamic" test failure -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK , ids=str) -def test_int4_unpack_op(shape, innerKTiles): +def test_unpack_tensor_core_tiled_layout_op(shape, innerKTiles): test_utils = [ "test_schema", "test_autograd_registration", @@ -109,7 +109,7 @@ def test_int4_unpack_op(shape, innerKTiles): packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) opcheck( - torch.ops.torchao.unpack_int4_to_int, + torch.ops.torchao.unpack_tensor_core_tiled_layout, (packed_w, innerKTiles), test_utils=test_utils, ) @@ -134,10 +134,10 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): return dq.reshape(n, k) -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_int4_correctness(shape, innerKTiles, group_size): +def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, group_size): n, k = shape dtype = torch.bfloat16 @@ -178,9 +178,8 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): ).t() # Actual operation to test - dq_op = torchao.ops.dequantize_int4(packed, scales_and_zeros, group_size, innerKTiles) + dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, innerKTiles) - # Compare results diff_ao_id = (dq_id - dq_ao).abs().max() diff_op_id = (dq_op - 
dq_id).abs().max() @@ -196,10 +195,10 @@ def test_dequantize_int4_correctness(shape, innerKTiles, group_size): # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix assert diff_op_ao == diff_ao_id -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_int4_op(shape, innerKTiles, group_size): +def test_dequantize_tensor_core_tiled_layout_op(shape, innerKTiles, group_size): n, k = shape device = "cuda" @@ -218,7 +217,7 @@ def test_dequantize_int4_op(shape, innerKTiles, group_size): # "test_aot_dispatch_dynamic", ] opcheck( - torch.ops.torchao.dequantize_int4, + torch.ops.torchao.dequantize_tensor_core_tiled_layout, (packed_w, scales_and_zeros, group_size, innerKTiles), test_utils=test_utils, ) diff --git a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu b/torchao/csrc/cuda/unpack_int4/unpack_int4.cu deleted file mode 100644 index 18724c6cc5..0000000000 --- a/torchao/csrc/cuda/unpack_int4/unpack_int4.cu +++ /dev/null @@ -1,312 +0,0 @@ -#include -#include -#include -#include -#include -#include - -template -constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { - static_assert(std::is_integral::value && std::is_integral::value, ""); - const uint64_t blocks = a / b + (a % b != 0); - return blocks; -} -constexpr int32_t kWarpSize = 32; - -//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization -//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180 -struct __align__(16) bf16x2x4 { - __nv_bfloat162 vals[4]; -}; - -//Copied from https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L195C1-L241C1 -inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { - bf16x2x4 result; - constexpr int kElements = 8; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = source; - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - - // We don't have enough mantissa to remove as much shift overhead as FP16, so - // we must loop. No shift needed for first item. - uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); -#pragma unroll - for (int ii = 1; ii < kElements / 2; ++ii) { - i4s >>= 4; // or is it 8? - // (i4s & 0x000f000f) | 0x43004300 - asm volatile( - "lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } - - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; - -// Finally, we construct the output numbers. 
-#pragma unroll - for (int ii = 0; ii < kElements / 2; ++ii) { - // Since this section is for Ampere+, we use bf16 fma to do the bias - // subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(h[ii]) - : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } - - return result; -} -// in size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] -// scales_and_zeros size [numQGroups][n][2] -// out size [n][k] -template -__global__ void _dequantize_int4_kernel( - const at::PackedTensorAccessor32 in, - at::PackedTensorAccessor32 out, - at::optional> scales_and_zeros = c10::nullopt) -{ - - constexpr int32_t kNTileSize = 8; - constexpr int32_t kKTileSize = 16; - - auto kOuterTile = blockIdx.x; - auto nTile = blockIdx.y; - auto t = threadIdx.x; - - // n dimension that this lane loads from - auto n0 = nTile * kNTileSize + (t / 4); - - // 8 k-tile values, 4 per m16n8k16 mma.sync operand B - // int32_t ks[8]; - //Only need 4 offsets since TC layout for single tile is 2x2 (2 pairs of 2 contiguous values) - int32_t ks[4]; - - // Store address base offset - auto pOut = &out[n0][0]; - -// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2 -#pragma unroll - for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { - //Tensor-core layout for m16n8k16 is such that each tile has 2 pairs of 2 contiguous values - //Hence, we only need 4 offsets - // Offsets of innerTile0 - auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; - ks[0] = kBase0 + (t % 4) * 2; - ks[1] = ks[0] + 8; - - // Offsets of innerTile1 - auto kBase1 = kBase0 + kKTileSize; - ks[2] = kBase1 + (t % 4) * 2; - ks[3] = ks[2] + 8; - - // inner k-tiles unpack two at a time - int32_t pack = in[nTile][kOuterTile][t][innerKTile / 2]; - - if constexpr(kDequant) { - // static_assert(scales_and_zeros.has_value(), "scales_and_zeros must be set when dequantizing"); - static_assert(std::is_same::value, "Out must be BFloat16 when dequantizing"); - // __nv_bfloat16 v[8]; - - // // Extract u4, convert to s4 by subtracting by 2 ** nbits / 2, then convert to bfloat16 - bf16x2x4 v_bf16x2x4 = convert_i4x8_to_bf16x2x4(pack); - - // All b values within a 16x16 tile should fall within the same q group - // Hence we load 1 scale and zero per loop - int qgroup = ks[0] / groupSize; - const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); - - // Vectorize scales and zeros - __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); - __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); - - #pragma unroll - for (int i = 0; i < 4; i++) { - reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2); - } - } - else { - static_assert(std::is_same::value, "Out must be int32_t when unpacking to int"); - int32_t v[8]; - - v[0] = pack & 0x0000000f; - v[2] = (pack >> 4) & 0x0000000f; - v[4] = (pack >> 8) & 0x0000000f; - v[6] = (pack >> 12) & 0x0000000f; - v[1] = (pack >> 16) & 0x0000000f; - v[3] = (pack >> 20) & 0x0000000f; - v[5] = (pack >> 24) & 0x0000000f; - v[7] = (pack >> 28) & 0x0000000f; - int2* v_i32x2 = reinterpret_cast(v); - - #pragma unroll - for (int i = 0; i < 4; ++i) { - reinterpret_cast(&pOut[ks[i]])[0] = v_i32x2[i]; - } - } - } -} - -// output is [n][k] (int32 dtype) -// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] -// scales_and_zeros is [numQGroups][n][2] -// qGroupSize is 32, 64, 128 or 256 -at::Tensor _dequantize_int4( - const at::Tensor& packed_w, - const at::Tensor& scales_and_zeros, - int64_t group_size, - int64_t 
innerKTiles) -{ - - constexpr int32_t kNTileSize = 8; - constexpr int32_t kKTileSize = 16; - - c10::cuda::CUDAGuard g(packed_w.device()); - - // packed_w preconditions - TORCH_CHECK(packed_w.dim() == 4); - TORCH_CHECK(packed_w.dtype() == at::kInt); - TORCH_CHECK(packed_w.is_contiguous()); - TORCH_CHECK(packed_w.size(2) == 32); - TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); - TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); - - auto numQGroups = scales_and_zeros.size(0); - int N = packed_w.size(0) * kNTileSize; - int K = packed_w.size(1) * innerKTiles * kKTileSize; - - // scales_and_zeros preconditions - TORCH_CHECK( - group_size == 32 || group_size == 64 || group_size == 128 || - group_size == 256); - TORCH_CHECK(numQGroups == K / group_size); - TORCH_CHECK(scales_and_zeros.dim() == 3); - TORCH_CHECK(scales_and_zeros.size(1) == N); - TORCH_CHECK(scales_and_zeros.size(2) == 2); - - auto nTiles = divUp(N, kNTileSize); - auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); - auto out = at::empty( - {N, K}, - at::TensorOptions().dtype(at::kBFloat16).device(packed_w.device())); - - auto stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(kSuperTiles, nTiles); - -#define RUN_DEQUANT(QGROUPSIZE) \ - do { \ - switch(innerKTiles) { \ - case 2: \ - _dequantize_int4_kernel<<>>( \ - packed_w.packed_accessor32(), \ - out.packed_accessor32(), \ - scales_and_zeros.packed_accessor32()); \ - break; \ - case 4: \ - _dequantize_int4_kernel<<>>( \ - packed_w.packed_accessor32(), \ - out.packed_accessor32(), \ - scales_and_zeros.packed_accessor32()); \ - break; \ - case 8: \ - _dequantize_int4_kernel<<>>( \ - packed_w.packed_accessor32(), \ - out.packed_accessor32(), \ - scales_and_zeros.packed_accessor32()); \ - break; \ - default: \ - break; \ - } \ - } while(false) - -#define DISPATCH_Q_GROUP() \ - do { \ - switch (group_size) { \ - case 32: \ - RUN_DEQUANT(32); \ - break; \ - case 64: \ - RUN_DEQUANT(64); \ - break; \ - case 128: \ - RUN_DEQUANT(128); \ - break; \ - case 256: \ - RUN_DEQUANT(256); \ - break; \ - default: \ - break; \ - } \ - } while(false) - - DISPATCH_Q_GROUP(); - #undef DISPATCH_Q_GROUP - #undef RUN_DEQUANT - - return out; -} - -// output is [n][k] (int32 dtype) -// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] -at::Tensor _unpack_int4_to_int( - const at::Tensor& packed_w, - int64_t innerKTiles) -{ - - c10::cuda::CUDAGuard g(packed_w.device()); - - TORCH_CHECK(packed_w.dim() == 4); - TORCH_CHECK(packed_w.dtype() == at::kInt); - TORCH_CHECK(packed_w.is_contiguous()); - - TORCH_CHECK(packed_w.size(2) == 32); - TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); - TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); - - int N = packed_w.size(0) * 8; - int K = packed_w.size(1) * innerKTiles * 16; - - constexpr int32_t kNTileSize = 8; - constexpr int32_t kKTileSize = 16; - - auto nTiles = divUp(N, kNTileSize); - - auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); - - auto out = at::empty( - {N, K}, - at::TensorOptions().dtype(at::kInt).device(packed_w.device())); - - auto stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(kSuperTiles, nTiles); - - if (innerKTiles == 2) { - _dequantize_int4_kernel<<>>( - packed_w.packed_accessor32(), - out.packed_accessor32()); - } - else if (innerKTiles == 4) { - _dequantize_int4_kernel<<>>( - packed_w.packed_accessor32(), - out.packed_accessor32()); - } else if (innerKTiles == 8) { - _dequantize_int4_kernel<<>>( - packed_w.packed_accessor32(), - out.packed_accessor32()); - } - - return 
out; -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::unpack_int4_to_int", &_unpack_int4_to_int); - m.impl("torchao::dequantize_int4", &_dequantize_int4); - -} diff --git a/torchao/csrc/unpack_int4.cpp b/torchao/csrc/unpack_int4.cpp deleted file mode 100644 index 87a8c176a4..0000000000 --- a/torchao/csrc/unpack_int4.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include -#include -#include - -TORCH_LIBRARY_FRAGMENT(torchao, m) { - m.impl_abstract_pystub("torchao.ops"); - m.def("unpack_int4_to_int(Tensor packed_w, int innerKTiles) -> Tensor"); - m.def("dequantize_int4(Tensor packed_w, Tensor scales_and_zeros, int group_size, int innerKTiles) -> Tensor"); - -} diff --git a/torchao/ops.py b/torchao/ops.py index 793cb93a97..f28e368b6c 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -46,7 +46,7 @@ def _(_in_feats, _weights, _scales, splitK = 1): -def unpack_int4_to_int(packed_w: Tensor, innerKTiles: int) -> Tensor: +def unpack_tensor_core_tiled_layout(packed_w: Tensor, innerKTiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. @@ -60,12 +60,12 @@ def unpack_int4_to_int(packed_w: Tensor, innerKTiles: int) -> Tensor: torch.tensor of shape is N x K, dtype is torch.int32 """ - return torch.ops.torchao.unpack_int4_to_int.default( + return torch.ops.torchao.unpack_tensor_core_tiled_layout.default( packed_w=packed_w, innerKTiles=innerKTiles ) -@register_custom_op(f"torchao::unpack_int4_to_int") +@register_custom_op(f"torchao::unpack_tensor_core_tiled_layout") def _(packed_w: Tensor, innerKTiles: int) -> Tensor: torch._check( packed_w.dim() == 4, @@ -89,7 +89,7 @@ def _(packed_w: Tensor, innerKTiles: int) -> Tensor: return torch.empty((N, K), dtype=torch.int32, device=packed_w.device) -def dequantize_int4(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: int) -> Tensor: +def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: int) -> Tensor: """ Dequantizes by: - Unpacking weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K` @@ -111,12 +111,12 @@ def dequantize_int4(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, torch.tensor of shape is N x K, dtype is torch.bfloat16 """ - return torch.ops.torchao.dequantize_int4.default( + return torch.ops.torchao.dequantize_tensor_core_tiled_layout.default( packed_w, scales_and_zeros, group_size, innerKTiles ) -@register_custom_op(f"torchao::dequantize_int4") +@register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout") def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: int) -> Tensor: # packed_w preconditions torch._check( @@ -147,4 +147,4 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1") torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") - return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) + return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) \ No newline at end of file From b6ad9f745e1c75196460edc3e49973dbae7e9bd1 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 28 Jun 2024 00:55:53 +0000 Subject: [PATCH 12/19] add renamed kernel source --- .../tensor_core_tiled_layout.cu | 312 ++++++++++++++++++ torchao/csrc/tensor_core_tiled_layout.cpp | 10 + 2 files changed, 322 
insertions(+) create mode 100644 torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu create mode 100644 torchao/csrc/tensor_core_tiled_layout.cpp diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu new file mode 100644 index 0000000000..652bba5ca6 --- /dev/null +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -0,0 +1,312 @@ +#include +#include +#include +#include +#include +#include + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + const uint64_t blocks = a / b + (a % b != 0); + return blocks; +} +constexpr int32_t kWarpSize = 32; + +//Simple data structure to represent 4 pairs of bfloat16s, used for vectorized dequantization +//https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L178-L180 +struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; +}; + +//Copied from https://github.com/pytorch/pytorch/blob/b6689e0fb83a1578959ab0d9c6d2d9e11f7df21a/aten/src/ATen/native/cuda/int4mm.cu#L195C1-L241C1 +inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#pragma unroll + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + +// Finally, we construct the output numbers. 
+#pragma unroll + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + + return result; +} +// in size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] +// scales_and_zeros size [numQGroups][n][2] +// out size [n][k] +template +__global__ void _dequantize_int4_kernel( + const at::PackedTensorAccessor32 in, + at::PackedTensorAccessor32 out, + at::optional> scales_and_zeros = c10::nullopt) +{ + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // n dimension that this lane loads from + auto n0 = nTile * kNTileSize + (t / 4); + + // 8 k-tile values, 4 per m16n8k16 mma.sync operand B + // int32_t ks[8]; + //Only need 4 offsets since TC layout for single tile is 2x2 (2 pairs of 2 contiguous values) + int32_t ks[4]; + + // Store address base offset + auto pOut = &out[n0][0]; + +// Unpack 2 k-tiles at a time since min pack size is InnerKTiles = 2 +#pragma unroll + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + //Tensor-core layout for m16n8k16 is such that each tile has 2 pairs of 2 contiguous values + //Hence, we only need 4 offsets + // Offsets of innerTile0 + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; + ks[0] = kBase0 + (t % 4) * 2; + ks[1] = ks[0] + 8; + + // Offsets of innerTile1 + auto kBase1 = kBase0 + kKTileSize; + ks[2] = kBase1 + (t % 4) * 2; + ks[3] = ks[2] + 8; + + // inner k-tiles unpack two at a time + int32_t pack = in[nTile][kOuterTile][t][innerKTile / 2]; + + if constexpr(kDequant) { + // static_assert(scales_and_zeros.has_value(), "scales_and_zeros must be set when dequantizing"); + static_assert(std::is_same::value, "Out must be BFloat16 when dequantizing"); + // __nv_bfloat16 v[8]; + + // // Extract u4, convert to s4 by subtracting by 2 ** nbits / 2, then convert to bfloat16 + bf16x2x4 v_bf16x2x4 = convert_i4x8_to_bf16x2x4(pack); + + // All b values within a 16x16 tile should fall within the same q group + // Hence we load 1 scale and zero per loop + int qgroup = ks[0] / groupSize; + const __nv_bfloat16 *pSZ = reinterpret_cast(&scales_and_zeros.value()[qgroup][n0][0]); + + // Vectorize scales and zeros + __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]); + __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]); + + #pragma unroll + for (int i = 0; i < 4; i++) { + reinterpret_cast<__nv_bfloat162*>(&pOut[ks[i]])[0] = __hfma2(v_bf16x2x4.vals[i], scale2, zero2); + } + } + else { + static_assert(std::is_same::value, "Out must be int32_t when unpacking to int"); + int32_t v[8]; + + v[0] = pack & 0x0000000f; + v[2] = (pack >> 4) & 0x0000000f; + v[4] = (pack >> 8) & 0x0000000f; + v[6] = (pack >> 12) & 0x0000000f; + v[1] = (pack >> 16) & 0x0000000f; + v[3] = (pack >> 20) & 0x0000000f; + v[5] = (pack >> 24) & 0x0000000f; + v[7] = (pack >> 28) & 0x0000000f; + int2* v_i32x2 = reinterpret_cast(v); + + #pragma unroll + for (int i = 0; i < 4; ++i) { + reinterpret_cast(&pOut[ks[i]])[0] = v_i32x2[i]; + } + } + } +} + +// output is [n][k] (int32 dtype) +// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] +// scales_and_zeros is [numQGroups][n][2] +// qGroupSize is 32, 64, 128 or 256 +at::Tensor _dequantize_tensor_core_tiled_layout( + const at::Tensor& packed_w, + const at::Tensor& scales_and_zeros, + int64_t group_size, 
+ int64_t innerKTiles) +{ + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + c10::cuda::CUDAGuard g(packed_w.device()); + + // packed_w preconditions + TORCH_CHECK(packed_w.dim() == 4); + TORCH_CHECK(packed_w.dtype() == at::kInt); + TORCH_CHECK(packed_w.is_contiguous()); + TORCH_CHECK(packed_w.size(2) == 32); + TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); + TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); + + auto numQGroups = scales_and_zeros.size(0); + int N = packed_w.size(0) * kNTileSize; + int K = packed_w.size(1) * innerKTiles * kKTileSize; + + // scales_and_zeros preconditions + TORCH_CHECK( + group_size == 32 || group_size == 64 || group_size == 128 || + group_size == 256); + TORCH_CHECK(numQGroups == K / group_size); + TORCH_CHECK(scales_and_zeros.dim() == 3); + TORCH_CHECK(scales_and_zeros.size(1) == N); + TORCH_CHECK(scales_and_zeros.size(2) == 2); + + auto nTiles = divUp(N, kNTileSize); + auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); + auto out = at::empty( + {N, K}, + at::TensorOptions().dtype(at::kBFloat16).device(packed_w.device())); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 grid(kSuperTiles, nTiles); + +#define RUN_DEQUANT(QGROUPSIZE) \ + do { \ + switch(innerKTiles) { \ + case 2: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + case 4: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + case 8: \ + _dequantize_int4_kernel<<>>( \ + packed_w.packed_accessor32(), \ + out.packed_accessor32(), \ + scales_and_zeros.packed_accessor32()); \ + break; \ + default: \ + break; \ + } \ + } while(false) + +#define DISPATCH_Q_GROUP() \ + do { \ + switch (group_size) { \ + case 32: \ + RUN_DEQUANT(32); \ + break; \ + case 64: \ + RUN_DEQUANT(64); \ + break; \ + case 128: \ + RUN_DEQUANT(128); \ + break; \ + case 256: \ + RUN_DEQUANT(256); \ + break; \ + default: \ + break; \ + } \ + } while(false) + + DISPATCH_Q_GROUP(); + #undef DISPATCH_Q_GROUP + #undef RUN_DEQUANT + + return out; +} + +// output is [n][k] (int32 dtype) +// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] +at::Tensor _unpack_tensor_core_tiled_layout( + const at::Tensor& packed_w, + int64_t innerKTiles) +{ + + c10::cuda::CUDAGuard g(packed_w.device()); + + TORCH_CHECK(packed_w.dim() == 4); + TORCH_CHECK(packed_w.dtype() == at::kInt); + TORCH_CHECK(packed_w.is_contiguous()); + + TORCH_CHECK(packed_w.size(2) == 32); + TORCH_CHECK(packed_w.size(3) == innerKTiles / 2); + TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); + + int N = packed_w.size(0) * 8; + int K = packed_w.size(1) * innerKTiles * 16; + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto nTiles = divUp(N, kNTileSize); + + auto kSuperTiles = divUp(K, innerKTiles * kKTileSize); + + auto out = at::empty( + {N, K}, + at::TensorOptions().dtype(at::kInt).device(packed_w.device())); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 grid(kSuperTiles, nTiles); + + if (innerKTiles == 2) { + _dequantize_int4_kernel<<>>( + packed_w.packed_accessor32(), + out.packed_accessor32()); + } + else if (innerKTiles == 4) { + _dequantize_int4_kernel<<>>( + packed_w.packed_accessor32(), + out.packed_accessor32()); + } else if (innerKTiles == 8) { + _dequantize_int4_kernel<<>>( + packed_w.packed_accessor32(), + 
out.packed_accessor32()); + } + + return out; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::unpack_tensor_core_tiled_layout", &_unpack_tensor_core_tiled_layout); + m.impl("torchao::dequantize_tensor_core_tiled_layout", &_dequantize_tensor_core_tiled_layout); + +} diff --git a/torchao/csrc/tensor_core_tiled_layout.cpp b/torchao/csrc/tensor_core_tiled_layout.cpp new file mode 100644 index 0000000000..a9c6e65280 --- /dev/null +++ b/torchao/csrc/tensor_core_tiled_layout.cpp @@ -0,0 +1,10 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def("unpack_tensor_core_tiled_layout(Tensor packed_w, int innerKTiles) -> Tensor"); + m.def("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int innerKTiles) -> Tensor"); + +} From 612d8e3906f69eb05ab72cdca659c42f4d7045fd Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 28 Jun 2024 01:05:02 +0000 Subject: [PATCH 13/19] add back test_aot_dispatch opcheck --- test/test_ops.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index adc4572f09..cc0b48090b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -103,7 +103,7 @@ def test_unpack_tensor_core_tiled_layout_op(shape, innerKTiles): "test_schema", "test_autograd_registration", "test_faketensor", - # "test_aot_dispatch_dynamic", + "test_aot_dispatch_dynamic", ] t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) @@ -204,7 +204,6 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, innerKTiles, group_size): q = torch.randint(0, 16, shape, dtype=torch.int, device=device) packed_w = torch._convert_weight_to_int4pack(q, innerKTiles) - print(packed_w.shape) q_groups = k // group_size scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) zeros = torch.randn_like(scales) @@ -214,7 +213,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, innerKTiles, group_size): "test_schema", "test_autograd_registration", "test_faketensor", - # "test_aot_dispatch_dynamic", + "test_aot_dispatch_dynamic", ] opcheck( torch.ops.torchao.dequantize_tensor_core_tiled_layout, From 9afa73e6b093dab76d519da73d8ab38c43c3d288 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 3 Jul 2024 19:33:26 +0000 Subject: [PATCH 14/19] rename innerKTiles to inner_k_tiles --- test/test_ops.py | 54 ++++++++++++----------- torchao/csrc/tensor_core_tiled_layout.cpp | 4 +- torchao/ops.py | 45 ++++++++++--------- 3 files changed, 53 insertions(+), 50 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index cc0b48090b..4e308ed9f4 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,20 +1,22 @@ import itertools +import unittest + +import pytest import torch -from torch.testing._internal.common_utils import TestCase, IS_FBCODE +from parameterized import parameterized +from torch.testing._internal.common_utils import IS_FBCODE, TestCase from torch.testing._internal.optests import opcheck + import torchao +import torchao.quantization from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2 -import unittest -from parameterized import parameterized -import pytest from torchao.quantization.utils import ( get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor_from_qparams, groupwise_affine_dequantize_tensor_from_qparams, + groupwise_affine_quantize_tensor_from_qparams, pack_tinygemm_scales_and_zeros, - 
unpack_tinygemm_scales_and_zeros + unpack_tinygemm_scales_and_zeros, ) -import torchao.quantization try: import torchao.ops @@ -84,21 +86,21 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") -@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str) -def test_unpack_tensor_core_tiled_layout_correctness(shape, innerKTiles): +@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str) +def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): N, K = shape - assert K % (innerKTiles * kTileSizeK) == 0 and N % kTileSizeN == 0 + assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0 t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) - unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, innerKTiles) - assert torch.allclose(t, unpacked) + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) + unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles) + assert torch.equal(t, unpacked) # TODO: Fix "test_aot_dispatch_dynamic" test failure @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") -@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK , ids=str) -def test_unpack_tensor_core_tiled_layout_op(shape, innerKTiles): +@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str) +def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): test_utils = [ "test_schema", "test_autograd_registration", @@ -106,11 +108,11 @@ def test_unpack_tensor_core_tiled_layout_op(shape, innerKTiles): "test_aot_dispatch_dynamic", ] t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") - packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles) + packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) opcheck( torch.ops.torchao.unpack_tensor_core_tiled_layout, - (packed_w, innerKTiles), + (packed_w, inner_k_tiles), test_utils=test_utils, ) @@ -136,8 +138,8 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") -@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, group_size): +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_correctness(shape, inner_k_tiles, group_size): n, k = shape dtype = torch.bfloat16 @@ -145,7 +147,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, gro nTileSize = 8 kTileSize = 16 nTiles = n // nTileSize - kTiles = k // (innerKTiles * kTileSize) + kTiles = k // (inner_k_tiles * kTileSize) numThreads = 32 device = "cuda" @@ -159,7 +161,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, gro ) # Pack to tensor core layout - packed = torch.ops.aten._convert_weight_to_int4pack(q, innerKTiles) 
+ packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) q_groups = k // group_size assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) @@ -178,7 +180,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, gro ).t() # Actual operation to test - dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, innerKTiles) + dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) # Compare results diff_ao_id = (dq_id - dq_ao).abs().max() @@ -197,13 +199,13 @@ def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, gro @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") -@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_tensor_core_tiled_layout_op(shape, innerKTiles, group_size): +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): n, k = shape device = "cuda" q = torch.randint(0, 16, shape, dtype=torch.int, device=device) - packed_w = torch._convert_weight_to_int4pack(q, innerKTiles) + packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) q_groups = k // group_size scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) zeros = torch.randn_like(scales) @@ -217,7 +219,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, innerKTiles, group_size): ] opcheck( torch.ops.torchao.dequantize_tensor_core_tiled_layout, - (packed_w, scales_and_zeros, group_size, innerKTiles), + (packed_w, scales_and_zeros, group_size, inner_k_tiles), test_utils=test_utils, ) diff --git a/torchao/csrc/tensor_core_tiled_layout.cpp b/torchao/csrc/tensor_core_tiled_layout.cpp index a9c6e65280..203d5d50c0 100644 --- a/torchao/csrc/tensor_core_tiled_layout.cpp +++ b/torchao/csrc/tensor_core_tiled_layout.cpp @@ -4,7 +4,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); - m.def("unpack_tensor_core_tiled_layout(Tensor packed_w, int innerKTiles) -> Tensor"); - m.def("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int innerKTiles) -> Tensor"); + m.def("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor"); + m.def("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor"); } diff --git a/torchao/ops.py b/torchao/ops.py index f28e368b6c..2afa5ebf73 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,5 +1,6 @@ import torch from torch import Tensor + from torchao.utils import TORCH_VERSION_AFTER_2_4 @@ -46,27 +47,27 @@ def _(_in_feats, _weights, _scales, splitK = 1): -def unpack_tensor_core_tiled_layout(packed_w: Tensor, innerKTiles: int) -> Tensor: +def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. 
- Assumes that the packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `innerKTiles = 2 | 4 | 8`" + Assumes that the packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `inner_k_tiles = 2 | 4 | 8`" Args: - packed_w: torch.tensor: 4D tensor with shape (N / 8) x (K / (innerKTiles * 16)) x 32 x innerKTiles, dtype is torch.int32 - innerKTiles: int + packed_w: torch.tensor: 4D tensor with shape (N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles, dtype is torch.int32 + inner_k_tiles: int Returns: torch.tensor of shape is N x K, dtype is torch.int32 """ return torch.ops.torchao.unpack_tensor_core_tiled_layout.default( - packed_w=packed_w, innerKTiles=innerKTiles + packed_w=packed_w, inner_k_tiles=inner_k_tiles ) @register_custom_op(f"torchao::unpack_tensor_core_tiled_layout") -def _(packed_w: Tensor, innerKTiles: int) -> Tensor: +def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor: torch._check( packed_w.dim() == 4, lambda: f"packed weight should be a 42d tensor, got {packed_w.dim()}D", @@ -76,20 +77,20 @@ def _(packed_w: Tensor, innerKTiles: int) -> Tensor: lambda: f"weight must be INT32, got {packed_w.dtype}", ) torch._check( - innerKTiles == 2 or innerKTiles == 4 or innerKTiles == 8, - lambda: "innerKTiles must be 2, 4, or 8", + inner_k_tiles == 2 or inner_k_tiles == 4 or inner_k_tiles == 8, + lambda: "inner_k_tiles must be 2, 4, or 8", ) torch._check(packed_w.size(2) == 32, lambda: "packed weight must have 32 at dim 2") torch._check( - packed_w.size(3) == innerKTiles / 2, - lambda: "packed weight must have innerKTiles/2 at dim 3", + packed_w.size(3) == inner_k_tiles / 2, + lambda: "packed weight must have inner_k_tiles/2 at dim 3", ) N = packed_w.size(0) * 8 - K = packed_w.size(1) * innerKTiles * 16 + K = packed_w.size(1) * inner_k_tiles * 16 return torch.empty((N, K), dtype=torch.int32, device=packed_w.device) -def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: int) -> Tensor: +def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: """ Dequantizes by: - Unpacking weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K` @@ -97,27 +98,27 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens - Dequantizing with the scales_and_zeros that were packed with `torchao.quantization.utils.pack_tinygemm_scales_and_zeros` Assumes: - - packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `innerKTiles = 2 | 4 | 8`" + - packed weights were generated with `torch.ops.aten._convert_weight_to_int4pack` with `inner_k_tiles = 2 | 4 | 8`" - packed scales_and_zeros were generated with `torchao.quantization.utils.pack_tinygemm_scales_and_zeros` - qGroupSize is 32 | 64 | 128 | 256 Args: - packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (innerKTiles * 16)) x 32 x innerKTiles / 2`, dtype is torch.int32 + packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles / 2`, dtype is torch.int32 scales_and_zeros: torch.tensor: 3D tensor with shape `numQGroups x N x 2`, dtype is torch.bfloat16 where numQGroups is K / qGroupSize qGroupSize: int - innerKTiles: int + inner_k_tiles: int Returns: torch.tensor of shape is N x K, dtype is torch.bfloat16 """ return torch.ops.torchao.dequantize_tensor_core_tiled_layout.default( - packed_w, 
scales_and_zeros, group_size, innerKTiles + packed_w, scales_and_zeros, group_size, inner_k_tiles ) @register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout") -def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: int) -> Tensor: +def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: # packed_w preconditions torch._check( packed_w.dim() == 4, @@ -128,16 +129,16 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, innerKTiles: lambda: f"weight must be INT32, got {packed_w.dtype}", ) torch._check( - innerKTiles == 2 or innerKTiles == 4 or innerKTiles == 8, - lambda: "innerKTiles must be 2, 4, or 8", + inner_k_tiles == 2 or inner_k_tiles == 4 or inner_k_tiles == 8, + lambda: "inner_k_tiles must be 2, 4, or 8", ) torch._check(packed_w.size(2) == 32, lambda: "packed weight must have 32 at dim 2") torch._check( - packed_w.size(3) == innerKTiles / 2, - lambda: "packed weight must have innerKTiles/2 at dim 3", + packed_w.size(3) == inner_k_tiles / 2, + lambda: "packed weight must have inner_k_tiles/2 at dim 3", ) N = packed_w.size(0) * 8 - K = packed_w.size(1) * innerKTiles * 16 + K = packed_w.size(1) * inner_k_tiles * 16 # scales_and_zeros preconditions torch._check(scales_and_zeros.dtype is torch.bfloat16, lambda: "scales_and_zeros must be bfloat16") From f05c720f30afb900a17c65defba7d1c8925515df Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 3 Jul 2024 19:59:38 +0000 Subject: [PATCH 15/19] add unpack and dequant test --- test/test_ops.py | 48 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4e308ed9f4..044748854c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -139,16 +139,9 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_tensor_core_tiled_layout_correctness(shape, inner_k_tiles, group_size): +def test_dequantize_tensor_core_tiled_layout_correctness_tinygemm(shape, inner_k_tiles, group_size): n, k = shape dtype = torch.bfloat16 - - # tinygemm params - nTileSize = 8 - kTileSize = 16 - nTiles = n // nTileSize - kTiles = k // (inner_k_tiles * kTileSize) - numThreads = 32 device = "cuda" @@ -166,6 +159,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness(shape, inner_k_tiles, g q_groups = k // group_size assert scales_and_zeros.shape == torch.Size([q_groups, n, 2]) + # Dequantize 'ao' ref dq_ao = groupwise_affine_dequantize_tensor_from_qparams( q, scales, zeros, n_bit=4, groupsize=group_size ) @@ -187,16 +181,48 @@ def test_dequantize_tensor_core_tiled_layout_correctness(shape, inner_k_tiles, g diff_op_id = (dq_op - dq_id).abs().max() diff_op_ao = (dq_op - dq_ao).abs().max() - # There are slight numerical differences when dequantizing with an identity matrix - # Since the `dequantize_int4` kernel relies on same underlying numerical conversions, this gives same - # numerical differences when compared to the `groupwise_affine_dequantize` + # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` + # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying 
bit twiddling tricks for fast + # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are + # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. # Test that the `dequant` kernel gives same results as identity matrix-based dequant assert diff_op_id == 0 # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix assert diff_op_ao == diff_ao_id + + assert diff_op_ao < 1e-1 + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") +@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) +def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size): + n, k = shape + dtype = torch.bfloat16 + device = "cuda" + + # Quantize and pack + t = torch.randn(n, k, dtype=dtype, device=device) + scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype) + q = groupwise_affine_quantize_tensor_from_qparams( + t, scales, zeros, n_bit=4, groupsize=group_size + ) + + packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) + + # Unpack and dequantize + unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles) + dq_ao = groupwise_affine_dequantize_tensor_from_qparams( + unpacked, scales, zeros, n_bit=4, groupsize=group_size + ) + # Actual operation to test + dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) + assert torch.allclose(dq_op, dq_ao, atol=1e-1) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) From c666a18bea01d9cc99a7037eaf6b777212d3599c Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 3 Jul 2024 20:37:11 +0000 Subject: [PATCH 16/19] additional numerical checks for unpack then dequant --- test/test_ops.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 044748854c..5cb8723aaa 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -139,7 +139,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) -def test_dequantize_tensor_core_tiled_layout_correctness_tinygemm(shape, inner_k_tiles, group_size): +def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size): n, k = shape dtype = torch.bfloat16 @@ -194,6 +194,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_tinygemm(shape, inner_k assert diff_op_ao < 1e-1 +# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") 
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) @@ -218,10 +219,35 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap unpacked, scales, zeros, n_bit=4, groupsize=group_size ) + # Dequantize by passing in an identity matrix as the activation + a_eye = torch.eye(k, device=device, dtype=dtype) + dq_id = torch.ops.aten._weight_int4pack_mm( + a_eye, + packed, + group_size, + scales_and_zeros, + ).t() + # Actual operation to test dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles) - assert torch.allclose(dq_op, dq_ao, atol=1e-1) + + # Compare results + diff_ao_id = (dq_id - dq_ao).abs().max() + diff_op_id = (dq_op - dq_id).abs().max() + diff_op_ao = (dq_op - dq_ao).abs().max() + + # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize` + # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast + # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are + # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`. + + # Test that the `dequant` kernel gives same results as identity matrix-based dequant + assert diff_op_id == 0 + + # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix + assert diff_op_ao == diff_ao_id + assert diff_op_ao < 1e-1 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") From e8ca817979b4ebbd6f9ecb5f2397d50bece25f24 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 4 Jul 2024 00:16:17 +0000 Subject: [PATCH 17/19] rebase test_ops on main --- test/test_ops.py | 97 +++++++++++++++++++++++++----------------------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 5cb8723aaa..58296c4f92 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,15 +1,27 @@ import itertools -import unittest -import pytest +import torchao + import torch -from parameterized import parameterized -from torch.testing._internal.common_utils import IS_FBCODE, TestCase +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) from torch.testing._internal.optests import opcheck +from torchao.utils import is_fbcode +# from torchao.prototype.quant_llm import from_scaled_tc_fpx +import pytest + +if is_fbcode(): + pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels") + +try: + import torchao.ops +except RuntimeError: + pytest.skip("torchao.ops not available") -import torchao -import torchao.quantization -from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2 from torchao.quantization.utils import ( get_groupwise_affine_qparams, groupwise_affine_dequantize_tensor_from_qparams, @@ -18,53 +30,51 @@ unpack_tinygemm_scales_and_zeros, ) -try: - import torchao.ops -except RuntimeError: - pytest.skip("torchao.ops not available") - -# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): -# test_faketensor failed with module 
'torch' has no attribute '_custom_ops' (scroll up for stack trace) -@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning") -@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") class TestOps(TestCase): - def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device): - # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. - fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) - fp16_scale = torch.rand(OC).half() + 0.5 - fp16_activation = torch.rand(BS, IC).half() + 0.5 - return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_llm_linear(self): + def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): + # Randomly initialize each byte + nbits = 1 + ebits + mbits + fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) + scale = torch.rand(OC).half() + 0.5 + fp16_act = torch.rand(BS, IC).half() + 0.5 + return fpx_weight.to(device), scale.to(device), fp16_act.to(device) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + def test_quant_llm_linear(self, ebits, mbits): BS = 2 OC = 256 IC = 256 splitK = 1 - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") + fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") # smoke test - torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK) + torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils) + opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): + # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py + fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py - @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK): - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") + results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) - results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK) + fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half() + results_fp16 = fp16_act @ fp16_weight.T - fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None] - results_fp16 = fp16_activation @ fp16_weight.T + error = 
(results_fpx - results_fp16).abs().mean() + gt = results_fp16.abs().mean() + relative_error = error / gt + assert relative_error < 1e-3 + +instantiate_parametrized_tests(TestOps) - error = (results_fp6 - results_fp16).abs() - relative_error = error / results_fp16.abs() - assert relative_error.mean() < 1e-2 ## Tests for `unpack_int4_packed` kTileSizeN = 8 @@ -85,7 +95,6 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK): TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES)) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): N, K = shape @@ -98,7 +107,6 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): # TODO: Fix "test_aot_dispatch_dynamic" test failure @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): test_utils = [ @@ -137,7 +145,6 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size): n, k = shape @@ -196,7 +203,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size): n, k = shape @@ -250,7 +256,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap assert diff_op_ao < 1e-1 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels") @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str) def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size): n, k = shape @@ -261,7 +266,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size q_groups = k // group_size scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device) zeros = torch.randn_like(scales) - scales_and_zeros = torchao.quantization.utils.pack_tinygemm_scales_and_zeros(scales, zeros) + scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) test_utils = [ "test_schema", @@ -276,4 +281,4 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size ) if 
__name__ == "__main__": - unittest.main() + run_tests() From e089ffb4ad5d21097c61ba6822b078a216aa13d0 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 4 Jul 2024 00:17:30 +0000 Subject: [PATCH 18/19] remove commented out code --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 58296c4f92..5138a1e2ed 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,7 @@ ) from torch.testing._internal.optests import opcheck from torchao.utils import is_fbcode -# from torchao.prototype.quant_llm import from_scaled_tc_fpx +from torchao.prototype.quant_llm import from_scaled_tc_fpx import pytest if is_fbcode(): From 75df5f5059e289cc7b771bc4cc47b6a64214d4a8 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 4 Jul 2024 00:56:23 +0000 Subject: [PATCH 19/19] skip dynamic opcheck unless torch>=2.5 --- test/test_ops.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 5138a1e2ed..45a10abe3a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -10,7 +10,7 @@ run_tests, ) from torch.testing._internal.optests import opcheck -from torchao.utils import is_fbcode +from torchao.utils import is_fbcode, TORCH_VERSION_AFTER_2_5 from torchao.prototype.quant_llm import from_scaled_tc_fpx import pytest @@ -76,7 +76,7 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): instantiate_parametrized_tests(TestOps) -## Tests for `unpack_int4_packed` +## Tests for `tensor_core_layout` kTileSizeN = 8 kTileSizeK = 16 @@ -113,8 +113,12 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): "test_schema", "test_autograd_registration", "test_faketensor", - "test_aot_dispatch_dynamic", ] + + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AFTER_2_5: + test_utils.append("test_aot_dispatch_dynamic") + t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda") packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles) @@ -272,8 +276,10 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size "test_schema", "test_autograd_registration", "test_faketensor", - "test_aot_dispatch_dynamic", ] + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AFTER_2_5: + test_utils.append("test_aot_dispatch_dynamic") opcheck( torch.ops.torchao.dequantize_tensor_core_tiled_layout, (packed_w, scales_and_zeros, group_size, inner_k_tiles), @@ -281,4 +287,4 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size ) if __name__ == "__main__": - run_tests() + run_tests() \ No newline at end of file