Skip to content

[WIP] Codebook quantization flow #1299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions test/prototype/test_codebook_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

import torch

from torchao.prototype.quantization.codebook import (
CodebookQuantizedTensor,
choose_qparams_codebook,
)
from torchao.quantization.utils import compute_error


class TestCodebookQuantization(unittest.TestCase):
def setUp(self):
torch.manual_seed(123)
self.input = torch.randn(100, 256, dtype=torch.float32)
self.block_size = (1, 1)
self.scale_block_size = 64
self.code_dtype = torch.uint8
self.chunk_size = 1024

def test_choose_qparams_codebook(self):
codebook, scales = choose_qparams_codebook(
self.input,
block_size=self.block_size,
scale_block_size=self.scale_block_size,
code_dtype=self.code_dtype,
)
self.assertEqual(codebook.dim(), len(self.block_size) + 1)

self.assertFalse(torch.isnan(codebook).any())
self.assertFalse(torch.isnan(scales).any())

def test_codebook_quantized_tensor_from_float(self):
cqt = CodebookQuantizedTensor.from_float(
self.input,
block_size=self.block_size,
code_dtype=self.code_dtype,
scale_block_size=self.scale_block_size,
chunk_size=self.chunk_size,
)

dequant = cqt.dequantize()

sqnr = compute_error(dequant, self.input)
self.assertGreater(sqnr, 30)

def test_codebook_quantized_tensor_from_float2(self):
block_size = (1, 16)
code_dtype = torch.int32
scale_block_size = self.input.shape[1]

cqt = CodebookQuantizedTensor.from_float(
self.input,
block_size=block_size,
code_dtype=code_dtype,
scale_block_size=scale_block_size,
chunk_size=self.chunk_size,
)

dequant = cqt.dequantize()

sqnr = compute_error(dequant, self.input)
self.assertGreater(sqnr, 30)


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def run_evaluation(
)
model.to(device)
model.reset_caches()
if "codebook" in quantization:
from torchao.prototype.quantization.codebook import codebook_weight_only
model.to(device)
quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))

if compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
Expand Down
4 changes: 4 additions & 0 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,10 @@ def ffn_or_attn_only(mod, fqn):

# do autoquantization
model.finalize_autoquant()
elif "codebook" in quantization:
from torchao.prototype.quantization.codebook import codebook_weight_only
model.to(device)
quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))

else:
if not TORCH_VERSION_AT_LEAST_2_5:
Expand Down
14 changes: 14 additions & 0 deletions torchao/prototype/quantization/codebook/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .codebook_ops import (
choose_qparams_codebook,
dequantize_codebook,
quantize_codebook,
)
from .codebook_quantized_tensor import CodebookQuantizedTensor, codebook_weight_only

__all__ = [
"CodebookQuantizedTensor",
"codebook_weight_only",
"quantize_codebook",
"dequantize_codebook",
"choose_qparams_codebook",
]
Loading
Loading