|
| 1 | +import gguf |
| 2 | +import torch |
| 3 | + |
| 4 | +from invokeai.backend.quantization.gguf.utils import ( |
| 5 | + DEQUANTIZE_FUNCTIONS, |
| 6 | + TORCH_COMPATIBLE_QTYPES, |
| 7 | + dequantize, |
| 8 | +) |
| 9 | + |
| 10 | + |
class GGMLTensor:
    """A tensor stored in a GGML quantization format, plus the metadata needed to dequantize it.

    The class implements ``__torch_function__`` so that torch functions called with a ``GGMLTensor`` argument are,
    by default, executed on dequantized copies of the quantized inputs.
    """

    def __init__(
        self,
        data: torch.Tensor,
        ggml_quantization_type: "gguf.GGMLQuantizationType",
        tensor_shape: torch.Size,
    ) -> None:
        """Store the raw tensor data together with its quantization metadata.

        Args:
            data: The underlying (possibly quantized) tensor data.
            ggml_quantization_type: The GGML quantization type of ``data``.
            tensor_shape: The dequantized shape of the tensor.
        """
        self._data = data
        self._ggml_quantization_type = ggml_quantization_type
        # The dequantized shape of the tensor.
        self._tensor_shape = tensor_shape

    def __repr__(self) -> str:
        """Return a short description containing the quantization type and the dequantized shape."""
        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape}))"

    def get_dequantized_tensor(self, dtype: torch.dtype) -> torch.Tensor:
        """Return the dequantized tensor.

        Args:
            dtype: The dtype of the dequantized tensor.

        Returns:
            The dequantized data, converted to ``dtype``.
        """
        qtype = self._ggml_quantization_type

        if qtype in TORCH_COMPATIBLE_QTYPES:
            # The data is already stored in a torch-native format, so a dtype cast is sufficient.
            return self._data.to(dtype)

        if qtype in DEQUANTIZE_FUNCTIONS:
            # TODO(ryand): Look into how the dtype param is intended to be used.
            return dequantize(data=self._data, qtype=qtype, oshape=self._tensor_shape, dtype=None).to(dtype)

        # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
        dequantized = gguf.quants.dequantize(self._data.cpu().numpy(), qtype)
        return torch.from_numpy(dequantized).to(self._data.device, dtype=dtype)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` with every quantized argument replaced by its dequantized tensor.

        Args:
            func: The torch function being overridden.
            types: The argument types that implement ``__torch_function__`` (unused).
            args: Positional arguments passed to ``func``.
            kwargs: Keyword arguments passed to ``func``, or None.

        Returns:
            The result of calling ``func`` on the dequantized arguments.
        """
        if kwargs is None:
            kwargs = {}

        # Most functions will work by simply running on the dequantized tensors, so we assume this as the default
        # behavior. Over time, we will have to add special handling for exceptions. For example, .to() will need
        # special handling.
        # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.float32.
        dtype = torch.float32
        dequantized_args = [cls._dequantize_nested(a, dtype) for a in args]
        dequantized_kwargs = {k: cls._dequantize_nested(v, dtype) for k, v in kwargs.items()}
        return func(*dequantized_args, **dequantized_kwargs)

    @staticmethod
    def _dequantize_nested(value, dtype):
        """Dequantize ``value`` if it is quantized, descending into plain lists and tuples."""
        if hasattr(value, "get_dequantized_tensor"):
            return value.get_dequantized_tensor(dtype=dtype)
        # Functions that take a sequence of tensors (e.g. torch.cat) must not receive quantized elements. Only exact
        # list/tuple instances are rebuilt so that subclasses (torch.Size, namedtuples) pass through unchanged.
        if type(value) in (list, tuple):
            return type(value)(GGMLTensor._dequantize_nested(item, dtype) for item in value)
        return value
0 commit comments