
Commit f06765d

RyanJDick authored and hipsterusername committed
Get alternative GGUF implementation working... barely.
1 parent: f347b26

File tree: 4 files changed, +124 −75 lines


invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 1 addition & 2 deletions
@@ -37,7 +37,6 @@
     convert_bundle_to_flux_transformer_checkpoint,
 )
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
-from invokeai.backend.quantization.gguf.torch_patcher import GGUFPatcher
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 try:
@@ -234,7 +233,7 @@ def _load_from_singlefile(
         assert isinstance(config, MainGGUFCheckpointConfig)
         model_path = Path(config.path)
 
-        with SilenceWarnings(), GGUFPatcher().wrap():
+        with SilenceWarnings():
            # Load the state dict and patcher
            sd = gguf_sd_loader(model_path)
            # Initialize the model
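
With the GGUFPatcher context gone, the tensor subclass in ggml_tensor.py (below) handles op dispatch itself, so loading needs no global torch patching. A minimal sketch of the resulting load path; DummyModule and the file path are illustrative placeholders, not InvokeAI code:

import torch
from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings


class DummyModule(torch.nn.Module):  # illustrative stand-in for the FLUX transformer
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)


model = DummyModule()
with SilenceWarnings():
    sd = gguf_sd_loader(Path("model.gguf"))  # placeholder path
    # assign=True keeps the loaded GGMLTensor wrappers intact instead of
    # copying their quantized bytes into the existing dense parameters.
    model.load_state_dict(sd, strict=False, assign=True)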

invokeai/backend/quantization/gguf/ggml_tensor.py

Lines changed: 57 additions & 21 deletions
@@ -8,7 +8,48 @@
 )
 
 
-class GGMLTensor:
+def dequantize_and_run(func, args, kwargs):
+    # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.bfloat16.
+    dequantized_args = [
+        a.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(a, "get_dequantized_tensor") else a for a in args
+    ]
+    dequantized_kwargs = {
+        k: v.get_dequantized_tensor(dtype=torch.bfloat16) if hasattr(v, "get_dequantized_tensor") else v
+        for k, v in kwargs.items()
+    }
+    return func(*dequantized_args, **dequantized_kwargs)
+
+
+def apply_to_quantized_tensor(func, args, kwargs):
+    ggml_tensor = args[0]
+    assert isinstance(ggml_tensor, GGMLTensor)
+    new_data = func(ggml_tensor._data, *args[1:], **kwargs)
+    return GGMLTensor(new_data, ggml_tensor._ggml_quantization_type, ggml_tensor._tensor_shape)
+
+
+GGML_TENSOR_OP_TABLE = {
+    torch.ops.aten.detach.default: apply_to_quantized_tensor,
+    torch.ops.aten._to_copy.default: apply_to_quantized_tensor,
+    # Ops above act on the quantized payload directly; ops below dequantize first.
+    torch.ops.aten.t.default: dequantize_and_run,
+    torch.ops.aten.addmm.default: dequantize_and_run,
+    torch.ops.aten.mul.Tensor: dequantize_and_run,
+}
+
+
+class GGMLTensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            data.shape,
+            dtype=data.dtype,
+            layout=data.layout,
+            device=data.device,
+            strides=data.stride(),
+            storage_offset=data.storage_offset(),
+        )
+
     def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantizationType, tensor_shape: torch.Size):
         self._data = data
         self._ggml_quantization_type = ggml_quantization_type
@@ -18,6 +59,17 @@ def __init__(self, data: torch.Tensor, ggml_quantization_type: gguf.GGMLQuantiza
     def __repr__(self):
         return f"GGMLTensor(type={self._ggml_quantization_type.name}, dequantized_shape=({self._tensor_shape})"
 
+    def size(self):
+        return self._tensor_shape
+
+    @property
+    def shape(self):
+        return self.size()
+
+    def requires_grad_(self, requires_grad: bool = True):
+        # TODO(ryand): Think about whether we should set requires_grad on the underlying tensor.
+        return self
+
     def get_dequantized_tensor(self, dtype: torch.dtype):
         """Return the dequantized tensor.
 
@@ -37,23 +89,7 @@ def get_dequantized_tensor(self, dtype: torch.dtype):
         return torch.from_numpy(new).to(self._data.device, dtype=dtype)
 
     @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-
-        # Most functions will work by simply running on the dequantized tensors, so we assume this as the default
-        # behavior. Over time, we will have to add special handling for exceptions. For example, .to() will need special
-        # handling.
-        if func in []:
-            return NotImplemented
-        else:
-            # TODO(ryand): Use the highest input precision of non-quantized inputs instead of hardcoding torch.float32.
-            dequantized_args = [
-                a.get_dequantized_tensor(dtype=torch.float32) if hasattr(a, "get_dequantized_tensor") else a
-                for a in args
-            ]
-            dequantized_kwargs = {
-                k: v.get_dequantized_tensor(dtype=torch.float32) if hasattr(v, "get_dequantized_tensor") else v
-                for k, v in kwargs.items()
-            }
-            return func(*dequantized_args, **dequantized_kwargs)
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if func in GGML_TENSOR_OP_TABLE:
+            return GGML_TENSOR_OP_TABLE[func](func, args, kwargs)
+        raise NotImplementedError(f"Unsupported function {func}")
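
The class above uses PyTorch's wrapper-subclass pattern: torch.Tensor._make_wrapper_subclass builds a metadata-only "shell" tensor while the quantized payload lives in self._data, and __torch_dispatch__ intercepts every aten op that touches the subclass, routing it through GGML_TENSOR_OP_TABLE. A self-contained toy sketch of the same pattern; UpcastTensor and its fp16-to-fp32 upcast are illustrative only, not InvokeAI code:

import torch


class UpcastTensor(torch.Tensor):
    """Toy wrapper subclass: holds fp16 data, upcasts to fp32 whenever an op runs."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        # The shell tensor carries only metadata; the payload stays in self._data.
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(a):
            # Replace every UpcastTensor argument with a plain fp32 tensor.
            return a._data.to(torch.float32) if isinstance(a, cls) else a

        return func(*[unwrap(a) for a in args], **{k: unwrap(v) for k, v in kwargs.items()})


x = UpcastTensor(torch.ones(4, 4, dtype=torch.float16))
y = x + torch.ones(4, 4)  # intercepted as torch.ops.aten.add.Tensor
print(y.dtype)  # torch.float32

In GGMLTensor the op table makes the split explicit: detach and _to_copy operate on the quantized payload directly, t and addmm (what nn.Linear's forward lowers to) plus mul.Tensor dequantize first, and anything else raises NotImplementedError so unsupported paths fail loudly.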

invokeai/backend/quantization/gguf/loaders.py

Lines changed: 64 additions & 50 deletions
@@ -5,64 +5,78 @@
 import gguf
 import torch
 
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.layers import GGUFTensor
-from invokeai.backend.quantization.gguf.utils import detect_arch
+from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
 
 
-def gguf_sd_loader(
-    path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
-) -> dict[str, GGUFTensor]:
-    """
-    Read state dict as fake tensors
-    """
+def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
     reader = gguf.GGUFReader(path)
 
-    prefix_len = len(handle_prefix)
-    tensor_names = {tensor.name for tensor in reader.tensors}
-    has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
-
-    tensors: list[tuple[str, gguf.ReaderTensor]] = []
+    sd: dict[str, GGUFTensor] = {}
     for tensor in reader.tensors:
-        sd_key = tensor_name = tensor.name
-        if has_prefix:
-            if not tensor_name.startswith(handle_prefix):
-                continue
-            sd_key = tensor_name[prefix_len:]
-        tensors.append((sd_key, tensor))
+        torch_tensor = torch.from_numpy(tensor.data)
+        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
+        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
+            torch_tensor = torch_tensor.view(*shape)
+        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
+    return sd
 
-    # detect and verify architecture
-    compat = None
-    arch_str = None
-    arch_field = reader.get_field("general.architecture")
-    if arch_field is not None:
-        if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
-            raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
-        arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
-        if arch_str not in {"flux"}:
-            raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
-    else:
-        arch_str = detect_arch({val[0] for val in tensors})
-        compat = "sd.cpp"
 
-    # main loading loop
-    state_dict: dict[str, GGUFTensor] = {}
-    qtype_dict: dict[str, int] = {}
-    for sd_key, tensor in tensors:
-        tensor_name = tensor.name
-        tensor_type_str = str(tensor.tensor_type)
-        torch_tensor = torch.from_numpy(tensor.data)  # mmap
+# def gguf_sd_loader(
+#     path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
+# ) -> dict[str, GGUFTensor]:
+#     """
+#     Read state dict as fake tensors
+#     """
+#     reader = gguf.GGUFReader(path)
 
-        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
-        # Workaround for stable-diffusion.cpp SDXL detection.
-        if compat == "sd.cpp" and arch_str == "sdxl":
-            if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
-                while len(shape) > 2 and shape[-1] == 1:
-                    shape = shape[:-1]
+#     prefix_len = len(handle_prefix)
+#     tensor_names = {tensor.name for tensor in reader.tensors}
+#     has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
 
-        # add to state dict
-        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
-            torch_tensor = torch_tensor.view(*shape)
-        state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
-        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
+#     tensors: list[tuple[str, gguf.ReaderTensor]] = []
+#     for tensor in reader.tensors:
+#         sd_key = tensor_name = tensor.name
+#         if has_prefix:
+#             if not tensor_name.startswith(handle_prefix):
+#                 continue
+#             sd_key = tensor_name[prefix_len:]
+#         tensors.append((sd_key, tensor))
+
+#     # detect and verify architecture
+#     compat = None
+#     arch_str = None
+#     arch_field = reader.get_field("general.architecture")
+#     if arch_field is not None:
+#         if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
+#             raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
+#         arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
+#         if arch_str not in {"flux"}:
+#             raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
+#     else:
+#         arch_str = detect_arch({val[0] for val in tensors})
+#         compat = "sd.cpp"
+
+#     # main loading loop
+#     state_dict: dict[str, GGUFTensor] = {}
+#     qtype_dict: dict[str, int] = {}
+#     for sd_key, tensor in tensors:
+#         tensor_name = tensor.name
+#         tensor_type_str = str(tensor.tensor_type)
+#         torch_tensor = torch.from_numpy(tensor.data)  # mmap
+
+#         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
+#         # Workaround for stable-diffusion.cpp SDXL detection.
+#         if compat == "sd.cpp" and arch_str == "sdxl":
+#             if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
+#                 while len(shape) > 2 and shape[-1] == 1:
+#                     shape = shape[:-1]
+
+#         # add to state dict
+#         if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
+#             torch_tensor = torch_tensor.view(*shape)
+#         state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
+#         qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
 
-    return state_dict
+#     return state_dict
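
A usage sketch of the simplified loader; the filename is a placeholder, and _ggml_quantization_type is the private attribute set by GGMLTensor above:

from pathlib import Path

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# Every value is a GGMLTensor: torch-compatible types (e.g. F16/F32) are viewed
# into their logical shape up front, while quantized types keep their raw byte
# payload until an op forces dequantization.
sd = gguf_sd_loader(Path("flux1-dev-Q4_0.gguf"))  # placeholder file
for name, tensor in list(sd.items())[:3]:
    print(f"{name}: {tensor._ggml_quantization_type.name} {tuple(tensor.shape)}")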

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -51,10 +51,10 @@ dependencies = [
     "sentencepiece==0.2.0",
     "spandrel==0.3.4",
     "timm==0.6.13",  # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
-    "torch==2.2.2",
+    "torch==2.4.1",
     "torchmetrics==0.11.4",
     "torchsde==0.2.6",
-    "torchvision==0.17.2",
+    "torchvision==0.19.1",
     "transformers==4.41.1",
 
     # Core application dependencies, pinned for reproducible builds.
