|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +from torchao.utils import ( |
| 4 | + _implements, |
| 5 | + _dispatch__torch_function__, |
| 6 | + _dispatch__torch_dispatch__, |
| 7 | +) |
| 8 | +from typing import Callable, Optional, Dict |
| 9 | +from torch.utils._python_dispatch import return_and_correct_aliasing |
| 10 | +from torchao.utils import ( |
| 11 | + TorchAOBaseTensor, |
| 12 | + TORCH_VERSION_AT_LEAST_2_5, |
| 13 | +) |
| 14 | + |
| 15 | +from torchao.quantization.observer import AffineQuantizedObserverBase |
| 16 | + |
| 17 | +__all__ = [ |
| 18 | + "LinearObserverTensor", |
| 19 | + "to_linear_observer_tensor", |
| 20 | + "insert_observers", |
| 21 | +] |
| 22 | + |
| 23 | +aten = torch.ops.aten |
| 24 | +Tensor = torch.Tensor |
| 25 | + |
| 26 | + |
| 27 | +class LinearObserverTensor(TorchAOBaseTensor): |
| 28 | + """ |
| 29 | + This subclass of Tensor is used in conjuction with a static calibration flow. |
| 30 | + The flow is broken up into 3 parts; |
| 31 | + 1. Insert the LinearObserverTensor subclass into the model's nn.Linear layers |
| 32 | + 2. Run the model with a calibration dataset, the observer will record the min/max of the input and weight |
| 33 | + 3. quantize_ the model to static using the statistics recorded by the observer |
| 34 | +
|
| 35 | + This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer |
| 36 | + will first calculat statistics on BOTH the input and weight, and then run the linear op. |
| 37 | + """ |
| 38 | + |
| 39 | + original_weight_tensor: torch.Tensor |
| 40 | + input_observer: Optional[AffineQuantizedObserverBase] |
| 41 | + weight_observer: Optional[AffineQuantizedObserverBase] |
| 42 | + |
| 43 | + def __new__( |
| 44 | + cls, |
| 45 | + original_weight_tensor: torch.Tensor, |
| 46 | + input_observer: Optional[AffineQuantizedObserverBase] = None, |
| 47 | + weight_observer: Optional[AffineQuantizedObserverBase] = None, |
| 48 | + ): |
| 49 | + kwargs = {} |
| 50 | + dtype = original_weight_tensor.dtype |
| 51 | + kwargs["dtype"] = dtype |
| 52 | + kwargs["requires_grad"] = False |
| 53 | + kwargs["device"] = original_weight_tensor.device |
| 54 | + shape = original_weight_tensor.shape |
| 55 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 56 | + |
| 57 | + def __init__( |
| 58 | + self, |
| 59 | + original_weight_tensor: torch.Tensor, |
| 60 | + input_observer: Optional[AffineQuantizedObserverBase] = None, |
| 61 | + weight_observer: Optional[AffineQuantizedObserverBase] = None, |
| 62 | + ): |
| 63 | + self.original_weight_tensor = original_weight_tensor |
| 64 | + self.input_observer = input_observer |
| 65 | + self.weight_observer = weight_observer |
| 66 | + |
| 67 | + def __repr__(self): |
| 68 | + return ( |
| 69 | + f"LinearObserverTensor(\n" |
| 70 | + f"original_weight={self.original_weight_tensor}\n" |
| 71 | + f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n" |
| 72 | + f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)" |
| 73 | + ) |
| 74 | + |
| 75 | + def __tensor_flatten__(self): |
| 76 | + return ["original_weight_tensor"], [self.input_observer, self.weight_observer] |
| 77 | + |
| 78 | + @classmethod |
| 79 | + def __tensor_unflatten__( |
| 80 | + cls, |
| 81 | + tensor_data_dict: Dict[str, Tensor], |
| 82 | + tensor_attributes, |
| 83 | + outer_size, |
| 84 | + outer_stride, |
| 85 | + ): |
| 86 | + original_weight_tensor = tensor_data_dict["original_weight_tensor"] |
| 87 | + (input_observer, weight_observer) = tensor_attributes |
| 88 | + return cls(original_weight_tensor, input_observer, weight_observer) |
| 89 | + |
| 90 | + @classmethod |
| 91 | + def from_float( |
| 92 | + cls, |
| 93 | + original_weight_tensor: Tensor, |
| 94 | + input_observer: Optional[AffineQuantizedObserverBase] = None, |
| 95 | + weight_observer: Optional[AffineQuantizedObserverBase] = None, |
| 96 | + ): |
| 97 | + return cls(original_weight_tensor, input_observer, weight_observer) |
| 98 | + |
| 99 | + def _apply_fn_to_data(self, fn: Callable): |
| 100 | + """Applies a fn to the tensor component of the LinearObserverTensor""" |
| 101 | + return self.__class__( |
| 102 | + fn(self.original_weight_tensor), |
| 103 | + self.input_observer, |
| 104 | + self.weight_observer, |
| 105 | + ) |
| 106 | + |
| 107 | + def to(self, *args, **kwargs): |
| 108 | + kwargs = self._get_to_kwargs(*args, **kwargs) |
| 109 | + return self._apply_fn_to_data(lambda x: x.to(**kwargs)) |
| 110 | + |
| 111 | + implements = classmethod(_implements) |
| 112 | + __torch_function__ = classmethod(_dispatch__torch_function__) |
| 113 | + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) |
| 114 | + |
| 115 | + |
| 116 | +implements = LinearObserverTensor.implements |
| 117 | + |
| 118 | + |
| 119 | +@implements(torch.nn.functional.linear) |
| 120 | +def _(func, types, args, kwargs): |
| 121 | + input_tensor, weight_tensor, bias = ( |
| 122 | + args[0], |
| 123 | + args[1], |
| 124 | + args[2] if len(args) > 2 else None, |
| 125 | + ) |
| 126 | + if weight_tensor.input_observer is not None: |
| 127 | + input_tensor = weight_tensor.input_observer(input_tensor) |
| 128 | + if weight_tensor.weight_observer is not None: |
| 129 | + weight_tensor = weight_tensor.weight_observer( |
| 130 | + weight_tensor.original_weight_tensor |
| 131 | + ) |
| 132 | + |
| 133 | + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) |
| 134 | + |
| 135 | + |
| 136 | +@implements(aten.detach.default) |
| 137 | +def _(func, types, args, kwargs): |
| 138 | + return return_and_correct_aliasing( |
| 139 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) |
| 140 | + ) |
| 141 | + |
| 142 | + |
| 143 | +@implements(aten.clone.default) |
| 144 | +def _(func, types, args, kwargs): |
| 145 | + return return_and_correct_aliasing( |
| 146 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) |
| 147 | + ) |
| 148 | + |
| 149 | + |
| 150 | +@implements(aten._to_copy.default) |
| 151 | +def _(func, types, args, kwargs): |
| 152 | + return return_and_correct_aliasing( |
| 153 | + func, |
| 154 | + args, |
| 155 | + kwargs, |
| 156 | + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), |
| 157 | + ) |
| 158 | + |
| 159 | + |
| 160 | +to_linear_observer_tensor = LinearObserverTensor.from_float |
| 161 | + |
| 162 | +if TORCH_VERSION_AT_LEAST_2_5: |
| 163 | + # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True` |
| 164 | + torch.serialization.add_safe_globals([LinearObserverTensor]) |
| 165 | + |
| 166 | + |
| 167 | +def insert_observers( |
| 168 | + input_observer: Optional[AffineQuantizedObserverBase], |
| 169 | + weight_observer: Optional[AffineQuantizedObserverBase], |
| 170 | +) -> Callable: |
| 171 | + """ |
| 172 | + Converts the weight of a linear module to a LinearObserverTensor. |
| 173 | +
|
| 174 | + This function wraps the weight of the given linear module with a LinearObserverTensor, |
| 175 | + which enables observation of both input and weight tensors during forward passes. |
| 176 | + The wrapped weight is then re-wrapped as a nn.Parameter to maintain compatibility |
| 177 | + with PyTorch's module system. |
| 178 | +
|
| 179 | + Usaage: |
| 180 | + ``` |
| 181 | + linear_module = nn.Linear(10, 20) |
| 182 | + quantize_(linear_module, convert_linear_weight_to_observer)) |
| 183 | +
|
| 184 | + Args: |
| 185 | + linear_module (nn.Linear): The linear module to be converted. |
| 186 | + input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor. |
| 187 | + weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor. |
| 188 | +
|
| 189 | + Returns: |
| 190 | + nn.Linear: The modified linear module with its weight wrapped in a LinearObserverTensor. |
| 191 | + """ |
| 192 | + |
| 193 | + def convert_to_linear_observer(linear_module: nn.Linear): |
| 194 | + # Wrap the weight with LinearObserverTensor and then with nn.Parameter |
| 195 | + linear_module.weight = nn.Parameter( |
| 196 | + to_linear_observer_tensor( |
| 197 | + linear_module.weight, |
| 198 | + input_observer=input_observer, |
| 199 | + weight_observer=weight_observer, |
| 200 | + ), |
| 201 | + requires_grad=linear_module.weight.requires_grad, |
| 202 | + ) |
| 203 | + return linear_module |
| 204 | + |
| 205 | + return convert_to_linear_observer |
0 commit comments