import gguf
import torch

+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.layers import GGUFTensor
-from invokeai.backend.quantization.gguf.utils import detect_arch
+from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES


-def gguf_sd_loader(
-    path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
-) -> dict[str, GGUFTensor]:
-    """
-    Read state dict as fake tensors
-    """
+def gguf_sd_loader(path: Path) -> dict[str, GGUFTensor]:
    reader = gguf.GGUFReader(path)

-    prefix_len = len(handle_prefix)
-    tensor_names = {tensor.name for tensor in reader.tensors}
-    has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
-
-    tensors: list[tuple[str, gguf.ReaderTensor]] = []
+    sd: dict[str, GGUFTensor] = {}
    for tensor in reader.tensors:
-        sd_key = tensor_name = tensor.name
-        if has_prefix:
-            if not tensor_name.startswith(handle_prefix):
-                continue
-            sd_key = tensor_name[prefix_len:]
-        tensors.append((sd_key, tensor))
+        torch_tensor = torch.from_numpy(tensor.data)
+        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
+        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
+            torch_tensor = torch_tensor.view(*shape)
+        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape)
+    return sd

-    # detect and verify architecture
-    compat = None
-    arch_str = None
-    arch_field = reader.get_field("general.architecture")
-    if arch_field is not None:
-        if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
-            raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
-        arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
-        if arch_str not in {"flux"}:
-            raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
-    else:
-        arch_str = detect_arch({val[0] for val in tensors})
-        compat = "sd.cpp"

-    # main loading loop
-    state_dict: dict[str, GGUFTensor] = {}
-    qtype_dict: dict[str, int] = {}
-    for sd_key, tensor in tensors:
-        tensor_name = tensor.name
-        tensor_type_str = str(tensor.tensor_type)
-        torch_tensor = torch.from_numpy(tensor.data)  # mmap
+# def gguf_sd_loader(
+#     path: Path, handle_prefix: str = "model.diffusion_model.", data_type: torch.dtype = torch.bfloat16
+# ) -> dict[str, GGUFTensor]:
+#     """
+#     Read state dict as fake tensors
+#     """
+#     reader = gguf.GGUFReader(path)

-        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
-        # Workaround for stable-diffusion.cpp SDXL detection.
-        if compat == "sd.cpp" and arch_str == "sdxl":
-            if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
-                while len(shape) > 2 and shape[-1] == 1:
-                    shape = shape[:-1]
+#     prefix_len = len(handle_prefix)
+#     tensor_names = {tensor.name for tensor in reader.tensors}
+#     has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

-        # add to state dict
-        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
-            torch_tensor = torch_tensor.view(*shape)
-        state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
-        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
+#     tensors: list[tuple[str, gguf.ReaderTensor]] = []
+#     for tensor in reader.tensors:
+#         sd_key = tensor_name = tensor.name
+#         if has_prefix:
+#             if not tensor_name.startswith(handle_prefix):
+#                 continue
+#             sd_key = tensor_name[prefix_len:]
+#         tensors.append((sd_key, tensor))
+
+#     # detect and verify architecture
+#     compat = None
+#     arch_str = None
+#     arch_field = reader.get_field("general.architecture")
+#     if arch_field is not None:
+#         if len(arch_field.types) != 1 or arch_field.types[0] != gguf.GGUFValueType.STRING:
+#             raise TypeError(f"Bad type for GGUF general.architecture key: expected string, got {arch_field.types!r}")
+#         arch_str = str(arch_field.parts[arch_field.data[-1]], encoding="utf-8")
+#         if arch_str not in {"flux"}:
+#             raise ValueError(f"Unexpected architecture type in GGUF file, expected flux, but got {arch_str!r}")
+#     else:
+#         arch_str = detect_arch({val[0] for val in tensors})
+#         compat = "sd.cpp"
+
+#     # main loading loop
+#     state_dict: dict[str, GGUFTensor] = {}
+#     qtype_dict: dict[str, int] = {}
+#     for sd_key, tensor in tensors:
+#         tensor_name = tensor.name
+#         tensor_type_str = str(tensor.tensor_type)
+#         torch_tensor = torch.from_numpy(tensor.data)  # mmap
+
+#         shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
+#         # Workaround for stable-diffusion.cpp SDXL detection.
+#         if compat == "sd.cpp" and arch_str == "sdxl":
+#             if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
+#                 while len(shape) > 2 and shape[-1] == 1:
+#                     shape = shape[:-1]
+
+#         # add to state dict
+#         if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
+#             torch_tensor = torch_tensor.view(*shape)
+#         state_dict[sd_key] = GGUFTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
+#         qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

-    return state_dict
+#     return state_dict
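
The new loader builds the torch shape by reversing the dimensions reported by the GGUF reader: ggml conventionally lists the innermost (fastest-varying) dimension first, while torch expects it last. A minimal sketch of that conversion, with illustrative dimension values rather than ones taken from this commit:

import torch

# Hypothetical GGUF-reported dims for a 2D weight: innermost dimension first.
gguf_dims = (3072, 12288)

# Same conversion as the loader above: reverse the order, then build a torch.Size.
torch_shape = torch.Size(tuple(int(v) for v in reversed(gguf_dims)))

assert torch_shape == torch.Size([12288, 3072])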
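For context, a minimal usage sketch of the simplified loader; the module path and checkpoint filename below are assumptions for illustration, not taken from this commit:

from pathlib import Path

# Assumed import location for the loader shown in the diff above.
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# Hypothetical path to a GGUF-quantized FLUX checkpoint.
state_dict = gguf_sd_loader(Path("/models/flux1-dev-Q8_0.gguf"))

# Each value is a GGMLTensor wrapping the raw tensor data; entries whose
# quantization type is torch-compatible (e.g. F16/F32) are already viewed in
# their logical shape, while other types keep their packed quantized payload.
print(f"Loaded {len(state_dict)} tensors")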