
Commit a2a2395

[StaticQuant] add a linear observer class and test
stack-info: PR: #807, branch: drisspg/stack/8
1 parent ca125ea commit a2a2395

File tree

5 files changed, +354 -19 lines


ruff.toml

+2
@@ -8,4 +8,6 @@ include = [
     "torchao/dtypes/nf4tensor.py",
     "test/dtypes/test_nf4.py",
     "torchao/float8/float8_tensor.py",
+    "torchao/quantization/linear_observer_tensor.py",
+    "test/quantization/test_observer.py",
 ]

test/quantization/test_observer.py

+98 -4
@@ -1,20 +1,31 @@
 import torch
+import torch.nn as nn
 from torch.testing._internal.common_utils import TestCase
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
     PerTensor,
     PerAxis,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.linear_observer_tensor import (
+    insert_observers,
+)
 import unittest
+
 # NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
 from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

+
 class TestQuantFlow(TestCase):
     def _test_obs_helper(self, obs1, obs2):
-        example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
+        example_inputs = [
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+        ]
         for example_input in example_inputs:
             obs1(example_input)
             obs2(example_input)
@@ -25,15 +36,98 @@ def _test_obs_helper(self, obs1, obs2):
         self.assertTrue(torch.allclose(zero_point1, zero_point2))

     def test_min_max_per_tensor_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
         ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
         self._test_obs_helper(obs, ref_obs)

     def test_min_max_per_channel_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
-        ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerAxis(axis=0),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
+        ref_obs = PerChannelMinMaxObserver(
+            dtype=torch.uint8, qscheme=torch.per_channel_affine
+        )
         self._test_obs_helper(obs, ref_obs)


+class TestLinearObserver(TestCase):
+    def test_linear_observer_tensor(self):
+        # Create a simple linear layer
+        in_features, out_features = 10, 5
+        linear = nn.Linear(in_features, out_features)
+
+        # Create observers
+        input_observer = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        weight_observer = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+
+        # Wrap the weight with LinearObserverTensor
+        quantize_(linear, insert_observers(input_observer, weight_observer))
+
+        # Create some example inputs
+        example_inputs = [torch.randn(5, in_features) for _ in range(3)]
+        max_val = 42.1234
+        min_val = -39.760
+        big_tensor = torch.full((6, in_features), max_val)
+        small_tensor = torch.full((40, in_features), min_val)
+        example_inputs.extend([big_tensor, small_tensor])
+
+        # Run forward passes
+        for example_input in example_inputs:
+            _ = linear(example_input)
+
+        input_observer = linear.weight.input_observer
+        weight_observer = linear.weight.weight_observer
+
+        # Check that the observers have recorded statistics
+        assert input_observer.min_val == min_val
+        assert input_observer.max_val == max_val
+
+        # Calculate qparams and ensure they're not None
+        input_scale, input_zero_point = input_observer.calculate_qparams()
+        weight_scale, weight_zero_point = weight_observer.calculate_qparams()
+
+        max_fp8 = torch.finfo(torch.float8_e4m3fn).max
+        self.assertEqual(
+            input_scale.item(),
+            max_val / max_fp8,
+        )
+        self.assertIsNotNone(input_zero_point)
+        torch.testing.assert_close(
+            weight_scale,
+            torch.max(linear.weight.original_weight_tensor) / max_fp8,
+            atol=5e-5,
+            rtol=0.0,
+        )
+        self.assertIsNotNone(weight_zero_point)
+
+
 if __name__ == "__main__":
     unittest.main()
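
For context on the assertions above: with MappingType.SYMMETRIC and a float8_e4m3fn target, the scale returned by calculate_qparams works out to max(|min_val|, |max_val|) divided by the largest representable float8 value, which is why the test compares input_scale against max_val / max_fp8 (here |42.1234| > |-39.760|). A minimal standalone sketch of that arithmetic, assuming the usual symmetric min/max formula rather than quoting the library's exact code path:

import torch

# Hypothetical standalone check of the arithmetic behind the assertions above.
max_val, min_val = 42.1234, -39.760
max_fp8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for float8_e4m3fn
expected_scale = max(abs(min_val), abs(max_val)) / max_fp8
print(expected_scale)  # ~0.094, equal to max_val / max_fp8 since |max_val| > |min_val|
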

torchao/quantization/linear_observer_tensor.py

+205

@@ -0,0 +1,205 @@
+import torch
+import torch.nn as nn
+from torchao.utils import (
+    _implements,
+    _dispatch__torch_function__,
+    _dispatch__torch_dispatch__,
+)
+from typing import Callable, Optional, Dict
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import (
+    TorchAOBaseTensor,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
+
+from torchao.quantization.observer import AffineQuantizedObserverBase
+
+__all__ = [
+    "LinearObserverTensor",
+    "to_linear_observer_tensor",
+    "insert_observers",
+]
+
+aten = torch.ops.aten
+Tensor = torch.Tensor
+
+
+class LinearObserverTensor(TorchAOBaseTensor):
+    """
+    This subclass of Tensor is used in conjunction with a static calibration flow.
+    The flow is broken up into 3 parts:
+    1. Insert the LinearObserverTensor subclass into the model's nn.Linear layers
+    2. Run the model with a calibration dataset; the observer will record the min/max of the input and weight
+    3. quantize_ the model to static using the statistics recorded by the observer
+
+    This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer
+    will first calculate statistics on BOTH the input and weight, and then run the linear op.
+    """
+
+    original_weight_tensor: torch.Tensor
+    input_observer: Optional[AffineQuantizedObserverBase]
+    weight_observer: Optional[AffineQuantizedObserverBase]
+
+    def __new__(
+        cls,
+        original_weight_tensor: torch.Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        kwargs = {}
+        dtype = original_weight_tensor.dtype
+        kwargs["dtype"] = dtype
+        kwargs["requires_grad"] = False
+        kwargs["device"] = original_weight_tensor.device
+        shape = original_weight_tensor.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        original_weight_tensor: torch.Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        self.original_weight_tensor = original_weight_tensor
+        self.input_observer = input_observer
+        self.weight_observer = weight_observer
+
+    def __repr__(self):
+        return (
+            f"LinearObserverTensor(\n"
+            f"original_weight={self.original_weight_tensor}\n"
+            f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
+            f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
+        )
+
+    def __tensor_flatten__(self):
+        return ["original_weight_tensor"], [self.input_observer, self.weight_observer]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls,
+        tensor_data_dict: Dict[str, Tensor],
+        tensor_attributes,
+        outer_size,
+        outer_stride,
+    ):
+        original_weight_tensor = tensor_data_dict["original_weight_tensor"]
+        (input_observer, weight_observer) = tensor_attributes
+        return cls(original_weight_tensor, input_observer, weight_observer)
+
+    @classmethod
+    def from_float(
+        cls,
+        original_weight_tensor: Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        return cls(original_weight_tensor, input_observer, weight_observer)
+
+    def _apply_fn_to_data(self, fn: Callable):
+        """Applies a fn to the tensor component of the LinearObserverTensor"""
+        return self.__class__(
+            fn(self.original_weight_tensor),
+            self.input_observer,
+            self.weight_observer,
+        )
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self._apply_fn_to_data(lambda x: x.to(**kwargs))
+
+    implements = classmethod(_implements)
+    __torch_function__ = classmethod(_dispatch__torch_function__)
+    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
+
+
+implements = LinearObserverTensor.implements
+
+
+@implements(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    if weight_tensor.input_observer is not None:
+        input_tensor = weight_tensor.input_observer(input_tensor)
+    if weight_tensor.weight_observer is not None:
+        weight_tensor = weight_tensor.weight_observer(
+            weight_tensor.original_weight_tensor
+        )
+
+    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.detach.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+@implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+
+to_linear_observer_tensor = LinearObserverTensor.from_float
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Allow a model with LinearObserverTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([LinearObserverTensor])
+
+
+def insert_observers(
+    input_observer: Optional[AffineQuantizedObserverBase],
+    weight_observer: Optional[AffineQuantizedObserverBase],
+) -> Callable:
+    """
+    Returns a callable that converts the weight of a linear module to a LinearObserverTensor.
+
+    The returned callable wraps the weight of the given linear module with a LinearObserverTensor,
+    which enables observation of both input and weight tensors during forward passes.
+    The wrapped weight is then re-wrapped as an nn.Parameter to maintain compatibility
+    with PyTorch's module system.
+
+    Usage:
+    ```
+    linear_module = nn.Linear(10, 20)
+    quantize_(linear_module, insert_observers(input_observer, weight_observer))
+    ```
+
+    Args:
+        input_observer (Optional[AffineQuantizedObserverBase]): Observer for the input tensor.
+        weight_observer (Optional[AffineQuantizedObserverBase]): Observer for the weight tensor.
+
+    Returns:
+        Callable[[nn.Linear], nn.Linear]: A function that wraps a linear module's weight in a LinearObserverTensor.
+    """
+
def convert_to_linear_observer(linear_module: nn.Linear):
194+
# Wrap the weight with LinearObserverTensor and then with nn.Parameter
195+
linear_module.weight = nn.Parameter(
196+
to_linear_observer_tensor(
197+
linear_module.weight,
198+
input_observer=input_observer,
199+
weight_observer=weight_observer,
200+
),
201+
requires_grad=linear_module.weight.requires_grad,
202+
)
203+
return linear_module
204+
205+
return convert_to_linear_observer
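
Putting the two files together, the three-step static calibration flow described in the LinearObserverTensor docstring would look roughly like the sketch below. It only uses names that appear in this diff (quantize_, insert_observers, AffineQuantizedMinMaxObserver, PerTensor, MappingType); step 3, converting the calibrated module to a statically quantized one, is not part of this commit, so the sketch only reads back the recorded qparams.

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.linear_observer_tensor import insert_observers
from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType

def make_observer():
    # Same observer configuration as in test_linear_observer_tensor above.
    return AffineQuantizedMinMaxObserver(
        MappingType.SYMMETRIC,
        torch.float8_e4m3fn,
        granularity_type=PerTensor(),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float,
        zero_point_dtype=torch.int,
        zero_point_domain=None,
    )

linear = nn.Linear(64, 16)

# Step 1: wrap the linear weight in a LinearObserverTensor.
quantize_(linear, insert_observers(make_observer(), make_observer()))

# Step 2: run calibration data; the observers record input and weight min/max.
for _ in range(8):
    linear(torch.randn(32, 64))

# Step 3 (not shown in this commit): use the recorded statistics to build a
# statically quantized module, starting from the observed qparams.
input_scale, input_zero_point = linear.weight.input_observer.calculate_qparams()
weight_scale, weight_zero_point = linear.weight.weight_observer.calculate_qparams()
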
