[WIP] Int4Tensor refactor to implements pattern #458

Status: Closed · wants to merge 6 commits
4 changes: 3 additions & 1 deletion test/dtypes/test_uint4.py
@@ -1,7 +1,9 @@
import torch
from torchao.dtypes.uint4 import (
    UInt4Tensor,
    PerChannelSymmetricWeightUInt4Tensor,
)
from torchao.dtypes import (
    PerChannelSymmetricWeightUInt4Tensor
)
import unittest
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -1,6 +1,8 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor

from .perchannel_symmetricweight import PerChannelSymmetricWeightUInt4Tensor
from .affine_quantized_tensor import (
    AffineQuantizedTensor,
    to_affine_quantized_intx,
133 changes: 133 additions & 0 deletions torchao/dtypes/channel_symmetricweight.py
@@ -0,0 +1,133 @@
import torch
from .uint4 import pack_uint4, unpack_uint4
from .uint4 import UInt4Tensor
from typing import Dict, Any

SYMMETRIC_WEIGHT_OPS_TABLE: Dict[Any, Any] = {}

def implements(aten_ops):
    def decorator(fn):
        for op in aten_ops:
            SYMMETRIC_WEIGHT_OPS_TABLE[op] = fn
        return fn
    return decorator

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
Review comment (Contributor):
also I feel this can probably be covered by AffineQuantizedTensor:
    "Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: quantized_tensor = float_tensor / scale + zero_point"
and the previous dequantize_per_channel is calling our unified quant_primitive ops:
    def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32):

Reply (Contributor, Author):
Working on this
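
For reference, a per-channel dequantize with the signature quoted above amounts to the affine transform below. This is a minimal plain-PyTorch sketch (axis assumed to be 0, name `dequantize_per_channel_sketch` is ours), not the torchao quant primitive itself:

```python
import torch

def dequantize_per_channel_sketch(int_repr, scales, zero_points, out_dtype=torch.float32):
    # int_repr: (out_channels, in_channels) integer tensor
    # scales, zero_points: (out_channels,) per-channel quantization parameters
    # float ~= (q - zero_point) * scale, broadcast along the channel axis
    return (int_repr.to(out_dtype) - zero_points.to(out_dtype).unsqueeze(-1)) * scales.to(out_dtype).unsqueeze(-1)
```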

    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)

    if target_dtype == torch.uint4:
        # TODO: simplify (maybe implement to)
        quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
            quant.to(torch.uint8), scale
        )
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor):
    @staticmethod
    def __new__(cls, elem, scales, **kwargs):
        return super(UInt4Tensor, cls).__new__(cls, elem, **kwargs)

    def __init__(self, elem, scales, **kwargs):
        super(UInt4Tensor, self).__init__(elem, **kwargs)

        self.scales = scales

    def __tensor_flatten__(self):
        return ["elem", "scales"], None

    @staticmethod
    def __tensor_unflatten__(flattened, meta, outer_size, outer_stride):
        assert meta is None
        elem = flattened["elem"]
        scales = flattened["scales"]
        return PerChannelSymmetricWeightUInt4Tensor(elem, scales)

    @classmethod  # inconsistently.
    def from_unpacked(cls, unpacked, scales):
        return cls(pack_uint4(unpacked), scales)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        def allowed_subclasses(type):
            return (
                issubclass(cls, type) or
                issubclass(torch._subclasses.fake_tensor.FakeTensor, type) or
                issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
            )

        if not all(allowed_subclasses(t) for t in types):
            # NotImplemented is a sentinel value, not an exception; returning it
            # lets the next subclass in line handle the op.
            return NotImplemented

        if func in SYMMETRIC_WEIGHT_OPS_TABLE:
            return SYMMETRIC_WEIGHT_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"UINT4 dispatch: attempting to run {func}, this is not supported")


    @classmethod
    def from_float(cls, w_fp32):
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, torch.uint4
        )
        w_int4 = w_int4.to(device=w_fp32.device)
        return w_int4

@implements([torch.ops.aten.addmm.default])
def addmm(func, args, kwargs):
    bias, x, weight = args
    x_view = x.view(-1, x.shape[-1])
    y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) * weight.scales
    y = y.reshape(*x.shape[:-1], -1)
    if bias is not None:
        y += bias
    return y

@implements([torch.ops.aten.t.default])
def t(func, args, kwargs):
    # TODO: add proper support for transpose
    (tensor,) = args
    unpacked = unpack_uint4(tensor.elem)
    transposed = torch.ops.aten.t.default(unpacked)
    return PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
        transposed, tensor.scales
    )

@implements([torch.ops.aten.detach.default])
def detach(func, args, kwargs):
    (tensor,) = args
    return tensor
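
As a quick numeric check of the scale formula in `_dynamically_quantize_per_channel_int4` above (scale = max(|min|, |max|) / ((quant_max - quant_min) / 2) per output channel), here is a small self-contained snippet with hand-picked values. It only mirrors the math above and is not part of the PR:

```python
import torch

x = torch.tensor([[-2.0, 0.5, 6.0, 1.0],
                  [ 3.0, -3.0, 1.5, 0.0]])
quant_min, quant_max = 0, 15

min_val, max_val = torch.aminmax(x, dim=1)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
max_val_pos = torch.max(-min_val_neg, max_val_pos)   # per-row absolute max
scale = torch.clamp(max_val_pos / ((quant_max - quant_min) / 2),
                    min=torch.finfo(torch.float32).eps)
print(scale)   # tensor([0.8000, 0.4000]) -> 6.0 / 7.5 and 3.0 / 7.5

quant = torch.clamp(torch.round(x.transpose(0, 1) / scale).transpose(0, 1),
                    quant_min, quant_max)
print(quant)   # rounded codes clamped to [0, 15] (still float dtype here)
```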
121 changes: 121 additions & 0 deletions torchao/dtypes/perchannel_symmetricweight.py
@@ -0,0 +1,121 @@
import torch
from torchao.dtypes.uint4 import pack_uint4, unpack_uint4
from torchao.dtypes import UInt4Tensor
from typing import Dict, Any
from torchao.dtypes.utils import _implements
from torchao.dtypes.utils import _dispatch__torch_function__, _dispatch__torch_dispatch__

SYMMETRIC_WEIGHT_OPS_TABLE: Dict[Any, Any] = {}

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)

    if target_dtype == torch.uint4:
        # TODO: simplify (maybe implement to)
        quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
            quant.to(torch.uint8), scale
        )
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor):
    @staticmethod
    def __new__(cls, elem, scales, **kwargs):
        return super().__new__(cls, elem, **kwargs)

    implements = classmethod(_implements)

    def __init__(self, elem, scales, **kwargs):
        super().__init__(elem, **kwargs)

        self.scales = scales

    def __tensor_flatten__(self):
        return ["elem", "scales"], None

    @staticmethod
    def __tensor_unflatten__(flattened, meta, outer_size, outer_stride):
        assert meta is None
        elem = flattened["elem"]
        scales = flattened["scales"]
        return PerChannelSymmetricWeightUInt4Tensor(elem, scales)

    @classmethod  # inconsistently.
    def from_unpacked(cls, unpacked, scales):
        return cls(pack_uint4(unpacked), scales)

    __torch_function__ = classmethod(_dispatch__torch_function__)

    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

    @classmethod
    def from_float(cls, w_fp32):
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, torch.uint4
        )
        w_int4 = w_int4.to(device=w_fp32.device)
        return w_int4

# NOTE: bind the classmethod so the module-level decorators below resolve;
# without this, `implements` is undefined at module scope.
implements = PerChannelSymmetricWeightUInt4Tensor.implements

@implements([torch.ops.aten.addmm.default])
def _(func, args, kwargs):
    bias, x, weight = args
    x_view = x.view(-1, x.shape[-1])
    y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) * weight.scales
    y = y.reshape(*x.shape[:-1], -1)
    if bias is not None:
        y += bias
    return y

@implements([torch.ops.aten.t.default])
def _(func, args, kwargs):
    # TODO: add proper support for transpose
    (tensor,) = args
    unpacked = unpack_uint4(tensor.elem)
    transposed = torch.ops.aten.t.default(unpacked)
    return PerChannelSymmetricWeightUInt4Tensor.from_unpacked(
        transposed, tensor.scales
    )

@implements([torch.ops.aten.detach.default])
def _(func, args, kwargs):
    (tensor,) = args
    return tensor

if __name__ == "__main__":
    # test
    x = torch.randn(2, 3, 4)
    w = torch.randn(5, 4)
    b = torch.randn(5)
    y = PerChannelSymmetricWeightUInt4Tensor.from_float(w)
    # print(y)
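
For readers unfamiliar with the 4-bit layout, pack_uint4 / unpack_uint4 (imported from torchao.dtypes.uint4 above) store two 4-bit values per uint8 byte. The snippet below is an illustrative stand-in written for this writeup, not the real helpers, which may use a different nibble order or memory layout:

```python
import torch

def pack_uint4_sketch(x):
    # x: uint8 tensor with values in [0, 15] and an even last dimension
    shape = x.shape
    pairs = x.contiguous().view(-1, 2)
    packed = (pairs[:, 0] << 4) | pairs[:, 1]   # high nibble, low nibble
    return packed.view(*shape[:-1], shape[-1] // 2)

def unpack_uint4_sketch(packed):
    # inverse of pack_uint4_sketch
    shape = packed.shape
    hi = (packed >> 4) & 0x0F
    lo = packed & 0x0F
    return torch.stack([hi, lo], dim=-1).view(*shape[:-1], shape[-1] * 2)

vals = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
assert torch.equal(unpack_uint4_sketch(pack_uint4_sketch(vals)), vals)
```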