
convert-hf : support direct Q8_0 conversion #7234


Merged
3 commits merged on May 13, 2024
72 changes: 28 additions & 44 deletions convert-hf-to-gguf.py
@@ -240,23 +240,6 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
return False

def write_tensors(self):
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def np_fp32_to_bf16(n: np.ndarray):
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
# flush subnormals to zero
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
# round to nearest even
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
return n.astype(np.int16)

# Doing this row-wise is much, much faster than element-wise, hence the signature
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
if self.lazy:
# TODO: find a way to implicitly wrap np.vectorize functions
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)

max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

for name, data_torch in self.get_tensors():
@@ -309,27 +292,31 @@ def np_fp32_to_bf16(n: np.ndarray):
))

if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data = gguf.quantize_bf16(data)
assert data.dtype == np.int16
data_qtype = gguf.GGMLQuantizationType.BF16

elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
data = gguf.quantize_q8_0(data)
assert data.dtype == np.uint8
data_qtype = gguf.GGMLQuantizationType.Q8_0

else: # default to float16 for quantized tensors
if data_dtype != np.float16:
data = data.astype(np.float16)
data_qtype = gguf.GGMLQuantizationType.F16

elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
if data_dtype != np.float32:
data = data.astype(np.float32)
data = v_fp32_to_bf16(data.view(np.int32))
assert data.dtype == np.int16
data_qtype = gguf.GGMLQuantizationType.BF16

else: # by default, convert to float32
if data_qtype is None: # by default, convert to float32
if data_dtype != np.float32:
data = data.astype(np.float32)
data_qtype = gguf.GGMLQuantizationType.F32

assert data_qtype is not None

block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
shape_str = f"""{{{', '.join(str(n) for n in reversed(
(*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
)}}}"""

# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
@@ -859,6 +846,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_file_type(self.ftype)

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear":
@@ -981,6 +969,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_file_type(self.ftype)

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear":
@@ -1215,6 +1204,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
self.gguf_writer.add_file_type(self.ftype)

_q_norms: list[dict[str, Tensor]] | None = None
_k_norms: list[dict[str, Tensor]] | None = None
@@ -1591,6 +1581,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)


@Model.register("Qwen2ForCausalLM")
@@ -1828,6 +1819,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
self.gguf_writer.add_file_type(self.ftype)

def shuffle_attn_q_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
@@ -2007,6 +1999,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
self.gguf_writer.add_file_type(self.ftype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_heads = self.hparams["num_attention_heads"]
@@ -2415,25 +2408,15 @@ class LazyTorchTensor(gguf.LazyBase):
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyNumpyTensor(
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
lazy=self._lazy,
args=(self,),
func=(lambda s: s[0].numpy())
)

@classmethod
def eager_to_meta(cls, t: Tensor) -> Tensor:
if t.is_meta:
return t
return t.detach().to("meta")

@classmethod
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
m = m.detach()
if not m.is_meta:
m = m.to("meta")
m.dtype = dtype
return m
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
return torch.empty(size=shape, dtype=dtype, device="meta")

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2464,8 +2447,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
@@ -2523,6 +2506,7 @@ def main() -> None:
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"auto": gguf.LlamaFileType.GUESSED,
}
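With the new mapping in place, a Q8_0 GGUF can be written directly from a Hugging Face checkpoint, roughly along the lines of `python convert-hf-to-gguf.py ./my-hf-model --outtype q8_0 --outfile my-model-q8_0.gguf` (the model path and output filename are placeholders, and the exact output flag depends on the script's other arguments). The add_file_type calls added to the various set_gguf_parameters methods above ensure the chosen file type is also recorded in the GGUF metadata.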

1 change: 1 addition & 0 deletions gguf-py/gguf/__init__.py
@@ -2,5 +2,6 @@
from .lazy import *
from .gguf_reader import *
from .gguf_writer import *
from .quants import *
from .tensor_mapping import *
from .vocab import *
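Re-exporting the new quants module from the package __init__ is what lets the converter reach the helpers through the top-level namespace. A trivial check, assuming the names used in convert-hf-to-gguf.py above:

import gguf

# names used by convert-hf-to-gguf.py, now reachable via the package root
print(callable(gguf.quantize_bf16), callable(gguf.quantize_q8_0), callable(gguf.can_quantize_to_q8_0))
print(gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0])  # expected: (32, 34)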
16 changes: 11 additions & 5 deletions gguf-py/gguf/gguf_writer.py
@@ -13,6 +13,7 @@
import numpy as np

from .constants import (
GGML_QUANT_SIZES,
GGUF_DEFAULT_ALIGNMENT,
GGUF_MAGIC,
GGUF_VERSION,
@@ -195,7 +196,7 @@ def ggml_pad(x: int, n: int) -> int:
return ((x + n - 1) // n) * n

def add_tensor_info(
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.state is not WriterState.EMPTY:
@@ -208,10 +209,6 @@ def add_tensor_info(
encoded_name = name.encode("utf-8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name
n_dims = len(tensor_shape)
self.ti_data += self._pack("I", n_dims)
for i in range(n_dims):
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
if raw_dtype is None:
if tensor_dtype == np.float16:
dtype = GGMLQuantizationType.F16
@@ -231,6 +228,15 @@
raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now")
else:
dtype = raw_dtype
if tensor_dtype == np.uint8:
block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
if tensor_shape[-1] % type_size != 0:
raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
Review comment on lines +231 to +235

@CISC (Collaborator), May 23, 2024:

This has broken copying of tensors on i-quants (and probably several others as well). Using

./gguf-new-metadata.py foo.IQ4_NL.gguf bar.gguf

you now get

ValueError: Quantized tensor row size (4096) is not a multiple of IQ4_NL type size (18)

Collaborator:

The issue seems to be that the type_size is off by 2; however, I don't see why the tensor should be reshaped in this scenario, so this should probably be re-evaluated.

Collaborator (Author):

Thanks for finding this! I think it also breaks copying of all other quantized tensors in gguf-new-metadata. Sorry about that.

I think I found a way to fix this while also simplifying what happens to the shape in the round-trip between GGUFReader and GGUFWriter. See #7483

n_dims = len(tensor_shape)
self.ti_data += self._pack("I", n_dims)
for i in range(n_dims):
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
self.ti_data += self._pack("I", dtype)
self.ti_data += self._pack("Q", self.offset_tensor)
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
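To make the new bookkeeping concrete, here is a small worked example of the last-dimension adjustment add_tensor_info now performs for quantized (np.uint8) data, together with the failure mode reported in the review comments above. The block/type sizes used (32/34 for Q8_0, 32/18 for IQ4_NL) come from ggml's block definitions; the snippet is an illustration, not the writer's actual code path.

# Q8_0: 32 elements per block, 34 bytes per block.
block_size, type_size = 32, 34
row_bytes = 4096 // block_size * type_size        # a 4096-element fp32 row quantizes to 4352 bytes
assert row_bytes % type_size == 0                 # the divisibility check passes
row_elems = row_bytes // type_size * block_size   # converted back to 4096 elements for the header

# The breakage reported above: gguf-new-metadata copies tensors that are already
# quantized, and the shape handed to the writer appears to be an element count
# (4096) rather than a byte count, so the same check trips for IQ4_NL
# (32 elements packed into 18 bytes):
block_size, type_size = 32, 18
print(4096 % type_size)  # 10, hence "row size (4096) is not a multiple of ... type size (18)"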
29 changes: 20 additions & 9 deletions gguf-py/gguf/lazy.py
@@ -6,6 +6,7 @@
from collections import deque

import numpy as np
from numpy._typing import _Shape
from numpy.typing import DTypeLike


@@ -110,7 +111,7 @@ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
return o

@classmethod
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
@@ -130,9 +131,14 @@ def wrapped_fn(*args, **kwargs):
res = args[0]
assert isinstance(res, cls)
res = res._meta
# allow operations to override the dtype
# allow operations to override the dtype and shape
if meta_noop is not True:
res = cls.meta_with_dtype(res, meta_noop)
if isinstance(meta_noop, tuple):
dtype, shape = meta_noop
assert callable(shape)
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
else:
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

if isinstance(res, cls._tensor_type):
def collect_replace(t: LazyBase):
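The meta_noop extension above is what allows an eager NumPy function to be wrapped lazily even when it changes both dtype and shape, as a quantizer does. A rough, self-contained sketch of how such a wrapper could be set up; the eager_q8_0 stand-in and the 32/34 block geometry are assumptions for illustration, and only the meta_noop=(dtype, shape_fn) mechanism comes from this diff.

import numpy as np
import gguf

def eager_q8_0(a: np.ndarray) -> np.ndarray:
    # stand-in for a real row-wise Q8_0 quantizer; only the output
    # dtype and shape matter for demonstrating the lazy wrapping
    return np.zeros((*a.shape[:-1], a.shape[-1] // 32 * 34), dtype=np.uint8)

def q8_0_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
    # each row of n float32 elements becomes n // 32 blocks of 34 bytes
    return (*shape[:-1], shape[-1] // 32 * 34)

lazy_q8_0 = gguf.LazyNumpyTensor._wrap_fn(eager_q8_0, meta_noop=(np.uint8, q8_0_shape))

lazy_data = gguf.LazyNumpyTensor.from_eager(np.ones((2, 64), dtype=np.float32))
result = lazy_q8_0(lazy_data)                  # nothing is computed yet
print(result._meta.shape, result._meta.dtype)  # (2, 68) uint8, derived from q8_0_shape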
@@ -168,7 +174,12 @@ def already_eager_to_eager(_t: LazyBase) -> Any:
while _t._data is None:
lt = _t._lazy.popleft()
if lt._data is not None:
raise ValueError(f"{lt} did not belong in the lazy queue")
# Lazy tensor did not belong in the lazy queue.
# Weirdly only happens with Bloom models...
# likely because tensors aren't unique in the queue.
# The final output is still the same as in eager mode,
# so it's safe to ignore this.
continue
assert lt._func is not None
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
lt._data = lt._func(lt._args)
@@ -183,12 +194,12 @@ def already_eager_to_eager(_t: LazyBase) -> Any:

@classmethod
def eager_to_meta(cls, t: Any) -> Any:
return cls.meta_with_dtype(t, t.dtype)
return cls.meta_with_dtype_and_shape(t.dtype, t.shape)

# must be overridden, meta tensor init is backend-specific
@classmethod
@abstractmethod
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass

@classmethod
def from_eager(cls, t: Any) -> Any:
@@ -205,15 +216,15 @@ class LazyNumpyTensor(LazyBase):
_tensor_type = np.ndarray

@classmethod
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
# The initial idea was to use np.nan as the fill value,
# but non-float types like np.int16 can't use that.
# So zero it is.
cheat = np.zeros(1, dtype)
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))

def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype(self._meta, dtype)
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
full_args = (self, dtype,) + args
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
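For readers unfamiliar with the stride trick used by meta_with_dtype_and_shape: numpy's as_strided with all-zero strides yields an array that reports an arbitrary shape and dtype while only a single element is actually allocated, which is what keeps the lazy "meta" tensors essentially free. A quick standalone check, not part of the PR:

import numpy as np

cheat = np.zeros(1, dtype=np.float16)                        # one real element (2 bytes)
meta = np.lib.stride_tricks.as_strided(cheat, (4096, 4096), (0, 0))
print(meta.shape, meta.dtype)                                # (4096, 4096) float16
# every "element" aliases the same 2 bytes, so never write through this view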