
Commit f347b26

RyanJDick authored and hipsterusername committed

Initial experimentation with Tensor-like extension for GGUF.

1 parent c665cf3 · commit f347b26

2 files changed: +76 -0 lines changed
invokeai/backend/quantization/gguf/ggml_tensor.py
Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
import gguf
import torch

from invokeai.backend.quantization.gguf.utils import (
    DEQUANTIZE_FUNCTIONS,
    TORCH_COMPATIBLE_QTYPES,
    dequantize,
)


class GGMLTensor:
    """A torch.Tensor-like wrapper around GGUF-quantized data that dequantizes on use."""

    def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
        self._data = data
        self._ggml_quantization_type = ggml_quantization_type
        # The dequantized shape of the tensor.
        self._tensor_shape = tensor_shape

    def __repr__(self):
        return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape={self._tensor_shape})"

    def get_dequantized_tensor(self, dtype: torch.dtype):
        """Return the dequantized tensor.

        Args:
            dtype: The dtype of the dequantized tensor.
        """
        if self._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES:
            return self._data.to(dtype)
        elif self._ggml_quantization_type in DEQUANTIZE_FUNCTIONS:
            # TODO(ryand): Look into how the dtype param is intended to be used.
            return dequantize(
                data=self._data, qtype=self._ggml_quantization_type, oshape=self._tensor_shape, dtype=None
            ).to(dtype)
        else:
            # There is no GPU implementation for this quantization type, so fall back to the numpy implementation.
            new = gguf.quants.dequantize(self._data.cpu().numpy(), self._ggml_quantization_type)
            return torch.from_numpy(new).to(self._data.device, dtype=dtype)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Most functions will work by simply running on the dequantized tensors, so we assume this as the default
        # behavior. Over time, we will have to add special handling for exceptions. For example, .to() will need
        # special handling.
        if func in []:
            return NotImplemented
        else:
            # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.float32.
            dequantized_args = [
                a.get_dequantized_tensor(dtype=torch.float32) if hasattr(a, "get_dequantized_tensor") else a
                for a in args
            ]
            dequantized_kwargs = {
                k: v.get_dequantized_tensor(dtype=torch.float32) if hasattr(v, "get_dequantized_tensor") else v
                for k, v in kwargs.items()
            }
            return func(*dequantized_args, **dequantized_kwargs)
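The empty exception list and the inline TODO above mark where per-op special handling will go, with .to() called out explicitly. Since GGMLTensor does not subclass torch.Tensor in this commit, a .to() call would fail at attribute lookup before __torch_function__ ever runs, so one plausible form of that handling is an ordinary method that relocates the quantized payload. A minimal sketch, assuming the hypothetical subclass name GGMLTensorWithTo and a device-only signature (neither is part of the commit):

import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


class GGMLTensorWithTo(GGMLTensor):
    """Hypothetical sketch, not part of this commit: .to() support for GGMLTensor."""

    def to(self, device: torch.device | str):
        # Move only the raw quantized payload to the target device. A dtype cast is
        # deliberately unsupported here, since casting the packed quantized bytes
        # would corrupt them; the qtype and dequantized-shape metadata are preserved.
        return type(self)(self._data.to(device=device), self._ggml_quantization_type, self._tensor_shape)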
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor


def test_ggml_tensor():
    """Smoke test that multiplication works on a GGMLTensor."""
    weight: GGUFTensor = torch.load("tests/assets/gguf_qweight.pt")
    tensor_shape = weight.tensor_shape
    tensor_type = weight.tensor_type
    data = torch.Tensor(weight.data)

    ggml_tensor = GGMLTensor(data, tensor_type, tensor_shape)
    ones = torch.ones([1], dtype=torch.float32)

    # Dispatches through GGMLTensor.__torch_function__, which dequantizes first.
    x = ggml_tensor * ones
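The test above depends on a checked-in asset (tests/assets/gguf_qweight.pt). The same dispatch path can be exercised without it by using a quantization type that is already torch-compatible, in which case get_dequantized_tensor reduces to a plain dtype cast. A minimal sketch, assuming gguf.GGMLQuantizationType.F16 appears in the TORCH_COMPATIBLE_QTYPES set that this commit imports from the utils module:

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

# F16 data needs no custom dequantization kernel, so dequantizing is just a cast.
data = torch.ones(4, 4, dtype=torch.float16)
ggml_tensor = GGMLTensor(data, gguf.GGMLQuantizationType.F16, data.shape)

# GGMLTensor defines no __mul__, so the multiply falls through to
# torch.Tensor.__rmul__, whose dispatch invokes GGMLTensor.__torch_function__;
# that dequantizes every GGMLTensor argument to float32 before running the op.
result = ggml_tensor * torch.ones(1, dtype=torch.float32)
assert result.dtype == torch.float32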
