
Commit f4af459

DerekLiu35 authored and committed
[WIP] Codebook quantization flow (#1299)
* Add codebook_ops
* Add codebook_quantized_tensor
* Add __init__.py
* Fix uint8 indexing
* Add codebook_weight_only
* add codebook to eval and generate
* Make scales max of scale group if block_size = (1, 1)
* generalize block_size to more than 2d
* add codebook section to README
* add greedy init to means
* change codes casting condition
* Update __init__.py
* Add tests
* add TODO
* make multiplication inplace
* store codebook and scales in input_tensor.dtype instead of float32
* update tests
* remove torch.allclose check
1 parent 22c4988 commit f4af459

File tree

7 files changed: +827 −0 lines changed

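For context: codebook quantization clusters the weights (this commit's k-means with greedy init), stores one small codebook of centroids plus low-bit codes indexing into it, and rescales per group of weights. A minimal conceptual sketch of the dequantization arithmetic, assuming scalar (1, 1) blocks and a scale group of 64 columns; every name below is illustrative, not the kernels this commit adds:

import torch

# Illustrative sizes: a 256-entry codebook (uint8 codes) for a (100, 256) weight.
codebook = torch.randn(256)                                   # cluster centroids
codes = torch.randint(0, 256, (100, 256), dtype=torch.uint8)  # one code per weight
scales = torch.rand(100, 4) + 0.5                             # one scale per 64 columns

# Dequantize: look up each code's centroid, then rescale each scale group.
dequant = codebook[codes.to(torch.int64)]                     # (100, 256)
dequant = (dequant.reshape(100, 4, 64) * scales[:, :, None]).reshape(100, 256)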

test/prototype/test_codebook_quant.py

Lines changed: 67 additions & 0 deletions
import unittest

import torch

from torchao.prototype.quantization.codebook import (
    CodebookQuantizedTensor,
    choose_qparams_codebook,
)
from torchao.quantization.utils import compute_error


class TestCodebookQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.input = torch.randn(100, 256, dtype=torch.float32)
        self.block_size = (1, 1)
        self.scale_block_size = 64
        self.code_dtype = torch.uint8
        self.chunk_size = 1024

    def test_choose_qparams_codebook(self):
        codebook, scales = choose_qparams_codebook(
            self.input,
            block_size=self.block_size,
            scale_block_size=self.scale_block_size,
            code_dtype=self.code_dtype,
        )
        self.assertEqual(codebook.dim(), len(self.block_size) + 1)

        self.assertFalse(torch.isnan(codebook).any())
        self.assertFalse(torch.isnan(scales).any())

    def test_codebook_quantized_tensor_from_float(self):
        cqt = CodebookQuantizedTensor.from_float(
            self.input,
            block_size=self.block_size,
            code_dtype=self.code_dtype,
            scale_block_size=self.scale_block_size,
            chunk_size=self.chunk_size,
        )

        dequant = cqt.dequantize()

        sqnr = compute_error(dequant, self.input)
        self.assertGreater(sqnr, 30)

    def test_codebook_quantized_tensor_from_float2(self):
        block_size = (1, 16)
        code_dtype = torch.int32
        scale_block_size = self.input.shape[1]

        cqt = CodebookQuantizedTensor.from_float(
            self.input,
            block_size=block_size,
            code_dtype=code_dtype,
            scale_block_size=scale_block_size,
            chunk_size=self.chunk_size,
        )

        dequant = cqt.dequantize()

        sqnr = compute_error(dequant, self.input)
        self.assertGreater(sqnr, 30)


if __name__ == "__main__":
    unittest.main()
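The 30 dB threshold above is a signal-to-quantization-noise ratio. A minimal sketch of the quantity compute_error returns, assuming it uses the standard SQNR-in-decibels definition:

import torch

def sqnr_db(signal: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||signal|| / ||signal - approx||).
    noise = signal - approx
    return 20 * torch.log10(torch.linalg.norm(signal) / torch.linalg.norm(noise))

x = torch.randn(100, 256)
x_hat = x + 0.01 * torch.randn_like(x)  # stand-in for a dequantized tensor
assert sqnr_db(x, x_hat) > 30           # roughly 40 dB for 1% relative noise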

torchao/_models/llama/eval.py

Lines changed: 4 additions & 0 deletions
@@ -194,6 +194,10 @@ def run_evaluation(
             )
             model.to(device)
             model.reset_caches()
+        if "codebook" in quantization:
+            from torchao.prototype.quantization.codebook import codebook_weight_only
+            model.to(device)
+            quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
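For a sense of the memory footprint this configuration targets: with torch.uint4 codes, scalar blocks, and scale_block_size=64, the effective bits per weight are roughly the code width plus the amortized scale storage (the commit stores codebook and scales in the input tensor's dtype). A back-of-the-envelope sketch under those assumptions:

# Rough bits-per-weight estimate; assumes one bf16 scale per 64 weights and a
# single 16-entry codebook whose cost is amortized over the whole tensor.
code_bits = 4                  # torch.uint4 codes, one per (1, 1) block
scale_bits = 16 / 64           # one bf16 scale per scale_block_size=64 group
print(code_bits + scale_bits)  # ~4.25 bits per weight, vs 16 for bf16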

torchao/_models/llama/generate.py

Lines changed: 4 additions & 0 deletions
@@ -711,6 +711,10 @@ def ffn_or_attn_only(mod, fqn):
 
             # do autoquantization
             model.finalize_autoquant()
+        elif "codebook" in quantization:
+            from torchao.prototype.quantization.codebook import codebook_weight_only
+            model.to(device)
+            quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))
 
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
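Both hunks reduce to the same two lines: import the new config and hand it to quantize_. A standalone sketch on a toy module; the model here is illustrative, while the quantize_ call is exactly the one added above:

import torch
from torch import nn
from torchao.quantization import quantize_
from torchao.prototype.quantization.codebook import codebook_weight_only

model = nn.Sequential(nn.Linear(256, 256))  # illustrative stand-in for Llama

# Replace every linear weight with a codebook-quantized tensor, as in the
# eval.py and generate.py hooks above.
quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))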
torchao/prototype/quantization/codebook/__init__.py

Lines changed: 14 additions & 0 deletions
from .codebook_ops import (
    choose_qparams_codebook,
    dequantize_codebook,
    quantize_codebook,
)
from .codebook_quantized_tensor import CodebookQuantizedTensor, codebook_weight_only

__all__ = [
    "CodebookQuantizedTensor",
    "codebook_weight_only",
    "quantize_codebook",
    "dequantize_codebook",
    "choose_qparams_codebook",
]
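These exports support two entry points: the module-level codebook_weight_only flow used by eval.py and generate.py, and a direct tensor round trip. A minimal sketch of the latter, reusing the arguments from the tests above:

import torch
from torchao.prototype.quantization.codebook import CodebookQuantizedTensor
from torchao.quantization.utils import compute_error

weight = torch.randn(100, 256, dtype=torch.float32)
cqt = CodebookQuantizedTensor.from_float(
    weight,
    block_size=(1, 1),       # one codebook entry per scalar weight
    code_dtype=torch.uint8,  # 256-entry codebook
    scale_block_size=64,     # one scale per 64 columns
    chunk_size=1024,
)
print(compute_error(cqt.dequantize(), weight))  # SQNR in dB; > 30 in the tests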
