diff --git a/test/dtypes/test_bitnet.py b/test/dtypes/test_bitnet.py deleted file mode 100644 index e248b04b05..0000000000 --- a/test/dtypes/test_bitnet.py +++ /dev/null @@ -1,91 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -from torchao.prototype.dtypes import BitnetTensor -from torchao.prototype.dtypes.uint2 import unpack_uint2 -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - -@pytest.fixture(autouse=True) -def run_before_and_after_tests(): - # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 - - # setup (currently do nothing) - - # tests will run here - yield - - # teardown - # avoid dynamo cache limit issues - torch._dynamo.reset() - - -@pytest.fixture -def bitnet_tensor(): - input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) - return BitnetTensor.from_unpacked(input_tensor) - - -def test_copy(bitnet_tensor): - copied_tensor = bitnet_tensor.clone() - assert torch.equal(bitnet_tensor.elem, copied_tensor.elem) - - -def test_transpose(bitnet_tensor): - transposed_tensor = bitnet_tensor.t() - expected_tensor = unpack_uint2(bitnet_tensor.elem).t() - assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) - - -def test_multiply(bitnet_tensor): - w_t = torch.randint(0, 15, (4, 16), dtype=torch.uint8) - w = BitnetTensor.from_unpacked(w_t) - torch.addmm(torch.Tensor([1]), bitnet_tensor, w) - - -@pytest.mark.parametrize( - "dtype", - [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], -) -def test_conversion(bitnet_tensor, dtype): - converted_tensor = bitnet_tensor.to(dtype) - expected_tensor = unpack_uint2(bitnet_tensor.elem).to(dtype) - assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) - - -def _apply_weight_only_uint2_quant(model): - def fn(mod): - mod.weight = torch.nn.Parameter( - BitnetTensor.from_float(mod.weight), requires_grad=False - ) - return mod - - _replace_with_custom_fn_if_matches_filter( - model, - lambda mod: fn(mod), - lambda mod, fqn: isinstance(mod, torch.nn.Linear), - ) - - -@pytest.mark.skipif( - TORCH_VERSION_AT_LEAST_2_5, reason="Regression introdued in nightlies" -) -@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]]) -def test_uint2_quant(input_shape): - device = "cuda" if torch.cuda.is_available() else "cpu" - x = torch.randn(*input_shape).to(device) - m = nn.Sequential(nn.Linear(4, 16)).to(device) - y_ref = m(x) - _apply_weight_only_uint2_quant(m) - y_wo = m(x) - assert y_ref.shape == y_wo.shape - torch.compile(m, fullgraph=True)(x) - - -if __name__ == "__main__": - pytest.main(__file__) diff --git a/test/dtypes/test_uint2.py b/test/dtypes/test_uint2.py deleted file mode 100644 index f6faaea10d..0000000000 --- a/test/dtypes/test_uint2.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest -import torch - -from torchao.prototype.dtypes import UInt2Tensor -from torchao.prototype.dtypes.uint2 import unpack_uint2 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - -if not TORCH_VERSION_AT_LEAST_2_4: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - - -@pytest.fixture -def uint2_tensor(): - input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8) - return UInt2Tensor(input_tensor) - - -def test_copy(uint2_tensor): - copied_tensor = uint2_tensor.clone() - assert torch.equal(uint2_tensor.elem, copied_tensor.elem) - - -def test_transpose(uint2_tensor): - transposed_tensor = uint2_tensor.t() - expected_tensor = unpack_uint2(uint2_tensor.elem).t() - assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor) - - -@pytest.mark.parametrize( - "dtype", - [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64], -) -def test_conversion(uint2_tensor, dtype): - converted_tensor = uint2_tensor.to(dtype) - expected_tensor = unpack_uint2(uint2_tensor.elem).to(dtype) - assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5) - - -if __name__ == "__main__": - pytest.main(__file__) diff --git a/test/prototype/test_bitpacking_gen.py b/test/prototype/test_bitpacking_gen.py deleted file mode 100644 index 288ac1e4fc..0000000000 --- a/test/prototype/test_bitpacking_gen.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -import torch - -from torchao.prototype.dtypes.uintgen import ( - pack_uint2, - pack_uint3, - pack_uint4, - pack_uint5, - pack_uint6, - pack_uint7, - unpack_uint2, - unpack_uint3, - unpack_uint4, - unpack_uint5, - unpack_uint6, - unpack_uint7, -) - - -@pytest.mark.parametrize( - "pack_fn, unpack_fn, bit_count", - [ - (pack_uint2, unpack_uint2, 2), - (pack_uint3, unpack_uint3, 3), - (pack_uint4, unpack_uint4, 4), - (pack_uint5, unpack_uint5, 5), - (pack_uint6, unpack_uint6, 6), - (pack_uint7, unpack_uint7, 7), - ], -) -def test_uint_packing(pack_fn, unpack_fn, bit_count): - x = torch.arange(0, 256, dtype=torch.uint8) - y = pack_fn(x) - z = unpack_fn(y) - k = z.view(-1, 2**bit_count) - check = torch.arange(0, 2**bit_count, dtype=torch.uint8).repeat(k.size(0), 1) - assert torch.all(k == check), f"Failed for {bit_count}-bit packing" - - -if __name__ == "__main__": - pytest.main(__file__) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py deleted file mode 100644 index 9393737aff..0000000000 --- a/torchao/prototype/dtypes/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .bitnet import BitnetTensor -from .uint2 import UInt2Tensor - -__all__ = [ - "BitnetTensor", - "UInt2Tensor", -] diff --git a/torchao/prototype/dtypes/bitnet.py b/torchao/prototype/dtypes/bitnet.py deleted file mode 100644 index 72b444acd4..0000000000 --- a/torchao/prototype/dtypes/bitnet.py +++ /dev/null @@ -1,200 +0,0 @@ -import torch - -from torchao.prototype.dtypes.uint2 import UInt2Tensor, pack_uint2, unpack_uint2 - -BITNET_OPS_TABLE = {} - - -def implements(aten_ops): - def decorator(fn): - for op in aten_ops: - BITNET_OPS_TABLE[op] = fn - return fn - - return decorator - - -def _quantize_int2(x: torch.Tensor) -> torch.Tensor: - # Quantize the input tensor to int2 - quant = x.sign() + 1 - quant = BitnetTensor.from_unpacked(quant.to(torch.uint8)) - return quant - - -class BitnetTensor(UInt2Tensor): - def __new__(cls, input_tensor: torch.Tensor, **kwargs): - return super(BitnetTensor, cls).__new__(cls, input_tensor, **kwargs) - - def __init__(self, input_tensor: torch.Tensor, **kwargs): - super(BitnetTensor, self).__init__(input_tensor, **kwargs) - - @staticmethod - def __tensor_unflatten__(flattened, *meta): - # TODO - meta is not None, is it ok? - elem = flattened["elem"] - return BitnetTensor(elem) - - @classmethod - def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor": - return cls(pack_uint2(unpacked)) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - def allowed_subclasses(type): - return ( - issubclass(cls, type) - or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) - or issubclass( - torch._subclasses.functional_tensor.FunctionalTensor, type - ) - ) - - if not all(allowed_subclasses(t) for t in types): - return NotImplemented("Bitnet, Up to the next one to handle") - - if func in BITNET_OPS_TABLE: - return BITNET_OPS_TABLE[func](func, args, kwargs) - raise NotImplementedError( - f"Bitnet dispatch: attempting to run {func}, this is not supported" - ) - - @classmethod - def from_float(cls, w: torch.Tensor): - w_intq = _quantize_int2(w) - w_int2 = w_intq.to(device=w.device) - return w_int2 - - def clone(self): - return BitnetTensor(self.elem.clone()) - - def copy_(self, src): - self.elem.copy_(src.elem) - return self - - def tolist(self): - data = unpack_uint2(self.elem).tolist() - return data - - def __repr__(self): - try: - data = unpack_uint2(self.elem).tolist() - except AssertionError: - data = f"Tensor of shape {self.shape} and dtype {self.elem.dtype}" - return f"BitnetTensor({data}, dtype={self.elem.dtype})" - - def to(self, *args, **kwargs): - if len(args) == 1 and isinstance(args[0], torch.dtype): - dtype = args[0] - if dtype == torch.int8: - return unpack_uint2(self.elem).view(self.shape).view(torch.int8) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return unpack_uint2(self.elem).to(torch.int8).to(dtype) - elif dtype == torch.uint8: - return unpack_uint2(self.elem).view(torch.uint8) - elif isinstance(self, BitnetTensor): - return self - if "device" in kwargs: - device = kwargs["device"] - return BitnetTensor(self.elem.to(device=device)) - - return super().to(*args, **kwargs) - - -@implements([torch.ops.aten.mm.default]) -def mm(func, args, kwargs): - x, weight = args - if isinstance(x, BitnetTensor): - x = unpack_uint2(x.elem).to(torch.float32) - if isinstance(weight, BitnetTensor): - weight = unpack_uint2(weight.elem).to(torch.float32) - y = torch.mm(x, weight) - return y - - -@implements([torch.ops.aten.addmm.default]) -def addmm(func, args, kwargs): - bias, x, weight = args - if isinstance(x, BitnetTensor): - x = unpack_uint2(x.elem).to(torch.float32) - if isinstance(weight, BitnetTensor): - weight = unpack_uint2(weight.elem).to(torch.float32) - if bias is not None: - bias = bias.to(torch.float32) - y = torch.addmm(bias, x, weight) - return y - - -@implements([torch.ops.aten.t.default]) -def t(func, args, kwargs): - (tensor,) = args - unpacked = unpack_uint2(tensor.elem).to(tensor.device) - transposed = unpacked.t() - return BitnetTensor(pack_uint2(transposed)) - - -@implements([torch.ops.aten.detach.default]) -def detach(func, args, kwargs): - (tensor,) = args - return tensor - - -@implements([torch.ops.aten.to.dtype]) -def to_dtype(func, args, kwargs): - (tensor, dtype) = args - if dtype == torch.int8: - return unpack_uint2(tensor.elem).view(torch.uint8) - 1 - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return unpack_uint2(tensor.elem).to(torch.int8).to(dtype) - elif dtype == torch.uint8: - return unpack_uint2(tensor.elem).view(torch.uint8) - elif isinstance(tensor, BitnetTensor): - return tensor.elem - raise NotImplementedError(f"to {dtype} not supported") - - -@implements([torch.ops.aten._to_copy.default]) -def _to_copy(func, args, kwargs): - (tensor,) = args - dtype = kwargs["dtype"] - if dtype == torch.int8: - return BitnetTensor( - unpack_uint2(tensor).view(tensor.shape).view(torch.int8) - 1 - ) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return BitnetTensor(tensor.to(torch.int8).to(dtype)) - elif isinstance(tensor, BitnetTensor): - return BitnetTensor(tensor) - raise NotImplementedError(f"to {dtype} not supported") - - -@implements([torch.ops.aten.clone.default]) -def clone(func, args, kwargs): - (tensor,) = args - return tensor.clone() - - -@implements([torch.ops.aten.allclose.default]) -def allclose(func, args, kwargs): - (a, b) = args - return torch.allclose(a.elem, b.elem, **kwargs) diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py deleted file mode 100644 index d54e541751..0000000000 --- a/torchao/prototype/dtypes/uint2.py +++ /dev/null @@ -1,280 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Tuple - -import torch -import torch._prims_common as utils - -from torchao.utils import fill_defaults - -UINT2_OPS_TABLE: Dict[Any, Any] = {} - - -def implements(aten_ops): - def decorator(fn): - for op in aten_ops: - UINT2_OPS_TABLE[op] = fn - return fn - - return decorator - - -def down_size(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by 4" - return (*size[:-1], size[-1] // 4) - - -def up_size(size): - return (*size[:-1], size[-1] * 4) - - -def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = (uint8_data >> 6) & 0b11 - second_elements = (uint8_data >> 4) & 0b11 - third_elements = (uint8_data >> 2) & 0b11 - fourth_elements = uint8_data & 0b11 - return torch.stack( - (first_elements, second_elements, third_elements, fourth_elements), dim=-1 - ).view(up_size(shape)) - - -def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - assert shape[-1] % 4 == 0, f"{shape}, last dim not divisible by 4" - uint8_data = uint8_data.contiguous().view(-1) - packed_data = ( - uint8_data[::4] << 6 - | uint8_data[1::4] << 4 - | uint8_data[2::4] << 2 - | uint8_data[3::4] - ).view(down_size(shape)) - return packed_data - - -@dataclass -class SubclassTensorArgs: - original_shape: torch.Size - original_strides: Tuple - storage_offset: int - dtype: torch.dtype - device: torch.device - requires_grad: bool - - -class UInt2Tensor(torch.Tensor): - def __new__(cls, input_tensor: torch.Tensor): - assert input_tensor.dtype == torch.uint8 - tensor_meta = SubclassTensorArgs( - input_tensor.size(), - input_tensor.stride(), - input_tensor.storage_offset(), - cls, - input_tensor.device, - input_tensor.requires_grad, - ) - uint2i_tensor = torch.Tensor._make_wrapper_subclass( - cls, - up_size(tensor_meta.original_shape), - tensor_meta.original_strides, - tensor_meta.storage_offset, - dtype=torch.uint8, # Not sure if this is correct - device=tensor_meta.device, - requires_grad=tensor_meta.requires_grad, - ) - return uint2i_tensor - - def __init__(self, input_tensor: torch.Tensor, **kwargs): - self.elem = input_tensor - - @classmethod - def from_packed(cls, unpacked): - return UInt2Tensor(pack_uint2(unpacked)) - - def tolist(self): - return unpack_uint2(self.elem).tolist() - - def __tensor_flatten__(self): - return ["elem"], None - - @staticmethod - def __tensor_unflatten__(flattened, meta): - assert meta is None - elem = flattened["elem"] - return UInt2Tensor(elem) - - def __hash__(self): - return hash(self.elem) - - def __eq__(self, other): - return torch.equal(self.elem, other.elem) - - def __repr__(self): - data = unpack_uint2(self.elem).tolist() - return f"UInt2Tensor({data}, dtype=torch.uint2)" - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - def allowed_subclasses(type): - return ( - issubclass(cls, type) - or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) - or issubclass( - torch._subclasses.functional_tensor.FunctionalTensor, type - ) - ) - - if not all(allowed_subclasses(t) for t in types): - return NotImplemented("Up to the next one to handle") - - if func in UINT2_OPS_TABLE: - return UINT2_OPS_TABLE[func](func, args, kwargs) - raise NotImplementedError( - f"UINT2 dispatch: attempting to run {func}, this is not supported" - ) - - -@implements([torch.ops.aten.view.default]) -def uint2_view(func, args, kwargs): - tensor, size = args - size = utils.infer_size(size, tensor.numel()) - assert not kwargs - dsize = down_size(size) - reshaped_elem = tensor.elem.view(dsize) - return UInt2Tensor(reshaped_elem) - - -@implements([torch.ops.aten.view.dtype]) -def view_dtype(func, args, kwargs): - tensor, dtype = args - if dtype is torch.uint8: - return unpack_uint2(tensor.elem).to(torch.uint8) - raise NotImplementedError(f"view {dtype} not supported") - - -@implements([torch.ops.aten.clone.default]) -def clone(func, args, kwargs): - tensor = args[0] - return UInt2Tensor(tensor.elem.clone()) - - -@implements([torch.ops.aten._unsafe_view.default]) -def unsafe_view(func, args, kwargs): - tensor, size = args - size = utils.infer_size(size, tensor.numel()) - assert not kwargs - dsize = down_size(size) - reshaped_elem = tensor.elem.view(dsize) - return UInt2Tensor(reshaped_elem) - - -@implements([torch.ops.aten.unbind.int]) -def unbind(func, args, kwargs): - tensor, dim = fill_defaults(args, 2, [0]) - if dim != tensor.dim() - 1: - raise NotImplementedError(f"unbind dim={dim}") - else: - x = tensor.elem.to(torch.uint8).unbind(dim) - return x - - -@implements([torch.ops.aten._to_copy.default]) -def to_copy(func, args, kwargs): - (tensor,) = args - dtype = kwargs["dtype"] - if dtype == torch.uint8: - return unpack_uint2(tensor.elem).view(tensor.shape).view(torch.uint8) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return tensor.to(torch.uint8).to(dtype) - elif isinstance(tensor, UInt2Tensor): - return tensor - raise NotImplementedError(f"to_copy {dtype} not supported") - - -@implements([torch.ops.aten.select.int]) -def select(func, args, kwargs): - tensor, dim, index = args - if dim != tensor.dim() - 1: - selected_elem = tensor.elem.select(dim, index) - return UInt2Tensor(selected_elem) - else: - raise NotImplementedError(f"select dim={dim}") - - -@implements([torch.ops.aten.reshape.default]) -def reshape(func, args, kwargs): - tensor, size = args - size = utils.infer_size(size, tensor.numel()) - assert not kwargs - dsize = down_size(size) - reshaped_elem = tensor.elem.view(dsize) - return UInt2Tensor(reshaped_elem) - - -def slice_tensor(func, args, kwargs): - tensor, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == tensor.dim() - 1: - if step != 1: - raise NotImplementedError(f"slice step={step}") - assert start % 4 == 0, start - assert end is None or end % 4 == 0, end - end = end if end is not None else tensor.shape[dim] - sliced_elem = tensor.elem[..., start // 4 : end // 4 : step] - return UInt2Tensor(sliced_elem) - else: - sliced_elem = tensor.elem[..., start:end:step] - return UInt2Tensor(sliced_elem) - - -@implements([torch.ops.aten.equal.default]) -def equal(func, args, kwargs): - tensor, other = args - return torch.equal(tensor.elem, other.elem) - - -@implements([torch.ops.aten.detach.default]) -def detach(func, args, kwargs): - (tensor,) = args - detached_elem = tensor.elem.detach() - return UInt2Tensor(detached_elem) - - -@implements([torch.ops.aten.to.dtype]) -def to_dtype(func, args, kwargs): - (tensor, dtype) = args - if dtype == torch.uint8: - return unpack_uint2(tensor.elem).view(torch.uint8) - elif dtype in ( - torch.float, - torch.float16, - torch.bfloat16, - torch.int16, - torch.int32, - torch.int64, - ): - return unpack_uint2(tensor.elem).to(torch.uint8).to(dtype) - elif isinstance(tensor, UInt2Tensor): - return tensor.elem - - raise NotImplementedError(f"to {dtype} not supported") - - -@implements([torch.ops.aten.t.default]) -def t(func, args, kwargs): - (tensor,) = args - unpacked = unpack_uint2(tensor.elem).to(tensor.device) - transposed = unpacked.t() - return UInt2Tensor(pack_uint2(transposed)) - - -@implements([torch.ops.aten.allclose.default]) -def allclose(func, args, kwargs): - tensor, other = args - return torch.allclose(tensor.elem, other.elem) diff --git a/torchao/prototype/dtypes/uintgen.py b/torchao/prototype/dtypes/uintgen.py deleted file mode 100644 index 192e4ad05a..0000000000 --- a/torchao/prototype/dtypes/uintgen.py +++ /dev/null @@ -1,385 +0,0 @@ -import torch - -""" -Contains generic functions to pack and unpack uintx (2-7) tensors into uint8 tensors. -""" - - -def down_size_uint2(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" - return (*size[:-1], size[-1] // 4) - - -def up_size_uint2(size): - return (*size[:-1], size[-1] * 4) - - -def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - # since we are using uint8 we will decode 4 entries per byte - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = (uint8_data >> 6) & 0b11 - second_elements = (uint8_data >> 4) & 0b11 - third_elements = (uint8_data >> 2) & 0b11 - fourth_elements = uint8_data & 0b11 - return torch.stack( - (first_elements, second_elements, third_elements, fourth_elements), dim=-1 - ).view(up_size_uint2(shape)) - - -def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor: - """pack lowest 2 bits of 2 uint8 -> 1 uint8""" - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = ( - (uint8_data[::4] & 0b11) << 6 - | (uint8_data[1::4] & 0b11) << 4 - | (uint8_data[2::4] & 0b11) << 2 - | (uint8_data[3::4] & 0b11) - ).view(down_size_uint2(shape)) - return packed_data - - -def down_size_uint3(size): - assert size[-1] % 8 == 0, f"{size} last dim not divisible by eight" - return (*size[:-1], size[-1] // 8 * 3) - - -def up_size_uint3(size): - assert size[-1] % 3 == 0, f"{size} last dim not divisible by three" - return (*size[:-1], size[-1] // 3 * 8) - - -def unpack_uint3(uint8_data: torch.Tensor) -> torch.Tensor: - """ - 3 -> 8 - 01234567|01234567|01234567 - AAABBBCC|CDDDEEEF|FFGGGHHH - """ - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - - return torch.stack( - ( - (uint8_data[::3] >> 5) & 0b111, - (uint8_data[::3] >> 2) & 0b111, - (uint8_data[::3] & 0b11) << 1 | (uint8_data[1::3] >> 7) & 0b1, - (uint8_data[1::3] >> 4) & 0b111, - (uint8_data[1::3] >> 1) & 0b111, - (uint8_data[1::3] & 0b1) << 2 | (uint8_data[2::3] >> 6) & 0b11, - (uint8_data[2::3] >> 3) & 0b111, - uint8_data[2::3] & 0b111, - ), - dim=-1, - ).view(up_size_uint3(shape)) - - -def pack_uint3(uint8_data: torch.Tensor) -> torch.Tensor: - """ - 8 -> 3 - 01234567|01234567|01234567 - AAABBBCC|CDDDEEEF|FFGGGHHH - """ - - shape = uint8_data.shape - assert shape[-1] % 8 == 0 - uint8_data = uint8_data.contiguous().view(-1) - - packed_data = torch.stack( - ( - ( - (uint8_data[::8] & 0b111) << 5 - | (uint8_data[1::8] & 0b111) << 2 - | (uint8_data[2::8] & 0b111) >> 1 - ), - ( - (uint8_data[2::8] & 0b1) << 7 - | (uint8_data[3::8] & 0b111) << 4 - | (uint8_data[4::8] & 0b111) << 1 - | ((uint8_data[5::8] >> 2) & 1) - ), - ( - (uint8_data[5::8] & 0b11) << 6 - | (uint8_data[6::8] & 0b111) << 3 - | (uint8_data[7::8] & 0b111) - ), - ), - dim=-1, - ).view(down_size_uint3(shape)) - - return packed_data - - -def down_size_uint4(size): - assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" - return (*size[:-1], size[-1] // 2) - - -def up_size_uint4(size): - return (*size[:-1], size[-1] * 2) - - -def unpack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - uint8_data = uint8_data.to(torch.uint8) - first_elements = (uint8_data >> 4) & 0b1111 - second_elements = uint8_data & 0b1111 - return torch.stack((first_elements, second_elements), dim=-1).view( - up_size_uint4(shape) - ) - - -def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: - shape = uint8_data.shape - assert shape[-1] % 2 == 0 - uint8_data = uint8_data.contiguous().view(-1) - packed_data = (uint8_data[::2] << 4 | (uint8_data[1::2] & 0b1111)).view( - down_size_uint4(shape) - ) - return packed_data - - -def down_size_uint5(size): - assert size[-1] % 8 == 0, f"{size} last dim not divisible by 8" - return (*size[:-1], size[-1] // 8 * 5) - - -def up_size_uint5(size): - assert size[-1] % 5 == 0, f"{size} last dim not divisible by 5" - return (*size[:-1], size[-1] // 5 * 8) - - -def pack_uint5(uint8_data: torch.Tensor) -> torch.Tensor: - """Pack the 5 lowest bits of 8 input bytes into 5 bytes - - 8 -> 5 - 01234567|01234567|01234567|01234567|01234567 - AAAAABBB|BBCCCCCD|DDDDEEEE|EFFFFFGG|GGGHHHHH - - The packing pattern: - - First byte: (A0 A1 A2 A3 A4 B0 B1 B2) - - Second byte: (B3 B4 C0 C1 C2 C3 C4 D0) - - Third byte: (D1 D2 D3 D4 E0 E1 E2 E3) - - Fourth byte: (E4 F0 F1 F2 F3 F4 G0 G1) - - Fifth byte: (G2 G3 G4 H0 H1 H2 H3 H4) - """ - shape = uint8_data.shape - assert ( - shape[-1] % 8 == 0 - ), f"Input last dimension should be divisible by 8, but got {shape[-1]}" - - uint8_data = uint8_data.contiguous().view(-1, 8) - - packed_data = torch.stack( - ( - ((uint8_data[:, 0] & 0b00011111) << 3) - | ((uint8_data[:, 1] & 0b00011100) >> 2), - ((uint8_data[:, 1] & 0b00000011) << 6) - | ((uint8_data[:, 2] & 0b00011111) << 1) - | ((uint8_data[:, 3] & 0b10000) >> 4), - ((uint8_data[:, 3] & 0b00001111) << 4) - | ((uint8_data[:, 4] & 0b00011110) >> 1), - ((uint8_data[:, 4] & 0b00000001) << 7) - | ((uint8_data[:, 5] & 0b00011111) << 2) - | ((uint8_data[:, 6] & 0b0011000) >> 3), - ((uint8_data[:, 6] & 0b00000111) << 5) | (uint8_data[:, 7] & 0b00011111), - ), - dim=-1, - ).view(down_size_uint5(shape)) - - return packed_data - - -def unpack_uint5(packed_data: torch.Tensor) -> torch.Tensor: - """Unpack the 5 bytes into the 5 lowest bits of 8 bytes - 01234567|01234567|01234567|01234567|01234567 - AAAAABBB|BBCCCCCD|DDDDEEEE|EFFFFFGG|GGGHHHHH - """ - shape = packed_data.shape - assert ( - shape[-1] % 5 == 0 - ), f"Input last dimension should be divisible by 5, but got {shape[-1]}" - - packed_data = packed_data.contiguous().view(-1, 5) - - unpacked_data = torch.stack( - ( - ((packed_data[:, 0] >> 3) & 0b00011111), - ((packed_data[:, 0] & 0b00000111) << 2) - | ((packed_data[:, 1] >> 6) & 0b00000011), - ((packed_data[:, 1] >> 1) & 0b00011111), - ((packed_data[:, 1] & 0b00000001) << 4) - | ((packed_data[:, 2] >> 4) & 0b00001111), - ((packed_data[:, 2] & 0b00001111) << 1) - | ((packed_data[:, 3] >> 7) & 0b00000001), - ((packed_data[:, 3] >> 2) & 0b00011111), - ((packed_data[:, 3] & 0b00000011) << 3) - | ((packed_data[:, 4] >> 5) & 0b00000111), - packed_data[:, 4] & 0b00011111, - ), - dim=-1, - ).view(up_size_uint5(shape)) - - return unpacked_data - - -def down_size_uint6(size): - assert size[-1] % 4 == 0, f"{size} last dim not divisible by four" - return (*size[:-1], size[-1] // 4 * 3) - - -def up_size_uint6(size): - assert size[-1] % 3 == 0, f"{size} last dim not divisible by three" - return (*size[:-1], size[-1] // 3 * 4) - - -def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - """Pack the 6 lowest bits of 4 input bytes into 3 bytes - - 4 -> 3 - 01234567|01234567|01234567 - AAAAAABB|BBBBCCCC|CCDDDDDD - - The packing pattern: - - First byte: (A0 A1 A2 A3 A4 A5 B0 B1) - - Second byte: (B2 B3 B4 B5 C0 C1 C2 C3) - - Third byte: (C4 C5 D0 D1 D2 D3 D4 D5) - """ - shape = uint8_data.shape - assert ( - shape[-1] % 4 == 0 - ), f"Input last dimension should be divisible by 4, but got {shape[-1]}" - - uint8_data = uint8_data.contiguous().view(-1, 4) - - packed_data = torch.stack( - ( - ((uint8_data[:, 0] & 0b00111111) << 2) - | ((uint8_data[:, 1] >> 4) & 0b00000011), - ((uint8_data[:, 1] & 0b00001111) << 4) - | ((uint8_data[:, 2] >> 2) & 0b00001111), - ((uint8_data[:, 2] & 0b00000011) << 6) | (uint8_data[:, 3] & 0b00111111), - ), - dim=-1, - ).view(down_size_uint6(shape)) - - return packed_data - - -def unpack_uint6(packed_data: torch.Tensor) -> torch.Tensor: - """Unpack the 3 bytes into the 6 lowest bits of 4 outputs - 01234567|01234567|01234567 - AAAAAABB|BBBBCCCC|CCDDDDDD - """ - shape = packed_data.shape - assert ( - shape[-1] % 3 == 0 - ), f"Input last dimension should be divisible by 3, but got {shape[-1]}" - - packed_data = packed_data.contiguous().view(-1, 3) - - unpacked_data = torch.stack( - ( - (packed_data[:, 0] >> 2) & 0b00111111, - ((packed_data[:, 0] & 0b00000011) << 4) - | ((packed_data[:, 1] >> 4) & 0b00001111), - ((packed_data[:, 1] & 0b00001111) << 2) - | ((packed_data[:, 2] >> 6) & 0b00000011), - packed_data[:, 2] & 0b00111111, - ), - dim=-1, - ).view(up_size_uint6(shape)) - - return unpacked_data - - -def down_size_uint7(size): - assert size[-1] % 8 == 0, f"{size} last dim not divisible by 8" - return (*size[:-1], size[-1] // 8 * 7) - - -def up_size_uint7(size): - assert size[-1] % 7 == 0, f"{size} last dim not divisible by 7" - return (*size[:-1], size[-1] // 7 * 8) - - -def pack_uint7(uint8_data: torch.Tensor) -> torch.Tensor: - """Pack the 7 lowest bits of 8 input bytes into 7 bytes - - 8 -> 7 - 01234567|01234567|01234567|01234567|01234567|01234567|01234567 - AAAAAAAB|BBBBBBCC|CCCCCDDD|DDDDEEEE|EEEFFFFF|FFGGGGGG|GHHHHHHH - - The packing pattern: - - First byte: (A0 A1 A2 A3 A4 A5 A6 B0) - - Second byte: (B1 B2 B3 B4 B5 B6 C0 C1) - - Third byte: (C2 C3 C4 C5 C6 D0 D1 D2) - - Fourth byte: (D3 D4 D5 D6 E0 E1 E2 E3) - - Fifth byte: (E4 E5 E6 F0 F1 F2 F3 F4) - - Sixth byte: (F5 F6 G0 G1 G2 G3 G4 G5) - - Seventh byte:(G6 H0 H1 H2 H3 H4 H5 H6) - """ - shape = uint8_data.shape - assert ( - shape[-1] % 8 == 0 - ), f"Input last dimension should be divisible by 8, but got {shape[-1]}" - - uint8_data = uint8_data.contiguous().view(-1, 8) - - packed_data = torch.stack( - ( - ((uint8_data[:, 0] & 0b01111111) << 1) - | ((uint8_data[:, 1] >> 6) & 0b00000001), - ((uint8_data[:, 1] & 0b00111111) << 2) - | ((uint8_data[:, 2] >> 5) & 0b00000011), - ((uint8_data[:, 2] & 0b00011111) << 3) - | ((uint8_data[:, 3] >> 4) & 0b00000111), - ((uint8_data[:, 3] & 0b00001111) << 4) - | ((uint8_data[:, 4] >> 3) & 0b00001111), - ((uint8_data[:, 4] & 0b00000111) << 5) - | ((uint8_data[:, 5] >> 2) & 0b00011111), - ((uint8_data[:, 5] & 0b00000011) << 6) - | ((uint8_data[:, 6] >> 1) & 0b00111111), - ((uint8_data[:, 6] & 0b00000001) << 7) - | ((uint8_data[:, 7] >> 0) & 0b01111111), - ), - dim=-1, - ).view(down_size_uint7(shape)) - - return packed_data - - -def unpack_uint7(packed_data: torch.Tensor) -> torch.Tensor: - """Unpack the 7 bytes into the 7 lowest bits of 8 bytes - 01234567|01234567|01234567|01234567|01234567|01234567|01234567 - AAAAAAAB|BBBBBBCC|CCCCCDDD|DDDDEEEE|EEEFFFFF|FFGGGGGG|GHHHHHHH - """ - shape = packed_data.shape - assert ( - shape[-1] % 7 == 0 - ), f"Input last dimension should be divisible by 7, but got {shape[-1]}" - - packed_data = packed_data.contiguous().view(-1, 7) - - unpacked_data = torch.stack( - ( - (packed_data[:, 0] >> 1) & 0b01111111, - ((packed_data[:, 0] & 0b00000001) << 6) - | ((packed_data[:, 1] >> 2) & 0b01111111), - ((packed_data[:, 1] & 0b00000011) << 5) - | ((packed_data[:, 2] >> 3) & 0b01111111), - ((packed_data[:, 2] & 0b00000111) << 4) - | ((packed_data[:, 3] >> 4) & 0b01111111), - ((packed_data[:, 3] & 0b00001111) << 3) - | ((packed_data[:, 4] >> 5) & 0b01111111), - ((packed_data[:, 4] & 0b00011111) << 2) - | ((packed_data[:, 5] >> 6) & 0b01111111), - ((packed_data[:, 5] & 0b00111111) << 1) - | ((packed_data[:, 6] >> 7) & 0b01111111), - packed_data[:, 6] & 0b01111111, - ), - dim=-1, - ).view(up_size_uint7(shape)) - - return unpacked_data