[WIP] Codebook quantization flow #1299
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1299
Note: links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit b938a7c with merge base 46b8796. BROKEN TRUNK: the following job failed but was also present on the merge base; rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for the contribution! Yeah, ">1 hour" seems a bit too slow; any ideas on how to speed it up?
Also, after this is done, it would be useful if you could add codebook quant to generate.py and eval.py:
ao/torchao/_models/llama/generate.py, line 209 in b714026
ao/torchao/_models/llama/eval.py, line 71 in b714026
I think
Changed max_iter from 1000 to 200; added codebook to eval and generate.
Thanks. Why is the perplexity so high? We get around 12/13 when using int4wo-64 on llama2: https://github.com/pytorch/ao/tree/main/torchao/quantization#cuda-backend
I found that it was because I was setting scales to the norm of each scale group instead of the max of each scale group. If you set scales to be the max of each scale group, wikitext perplexity is ~11.6 for
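For reference, a minimal pure-PyTorch sketch of the max-of-scale-group idea described above (the helper name and group layout are illustrative, not the PR's actual code):

import torch

def group_scales_maxabs(weight: torch.Tensor, scale_group_size: int) -> torch.Tensor:
    # one scale per group of `scale_group_size` columns, taken as the max
    # absolute value in the group rather than the group norm
    out_features, in_features = weight.shape
    assert in_features % scale_group_size == 0
    grouped = weight.reshape(out_features, in_features // scale_group_size, scale_group_size)
    return grouped.abs().amax(dim=-1, keepdim=True)

w = torch.randn(8, 16)
scales = group_scales_maxabs(w, scale_group_size=4)  # shape (8, 4, 1)
w_normalized = w.reshape(8, 4, 4) / scales            # values now lie in [-1, 1]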
Thanks. If both performance and accuracy are reasonable, I think the main thing is to add a section for codebook quant: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#uintx-quantization and show the perplexity and tokens/s results from eval.py and generate.py.
else:
    codebook_size = _DTYPE_TO_QVALUE_BOUNDS[code_dtype][1] + 1

out_block_size, in_block_size = block_size
block_size is a general arg that allows people to do all kinds of granularities:

ao/torchao/quantization/quant_primitives.py, lines 277 to 287 in 53d2486:

Note: How can block_size represent different granularities?
Let's say we have a Tensor of size (3, 3, 10, 10); here is the table showing how block_size represents different granularities:

granularity type                       | block_size
per_tensor                             | (3, 3, 10, 10)
per_axis (axis=0)                      | (1, 3, 10, 10)
per_axis (axis=1)                      | (3, 1, 10, 10)
per_group (groupsize=2)                | (3, 3, 10, 2)
per_group (groupsize=2) for axis = 3   | (3, 3, 2, 10)
ao/torchao/quantization/quant_primitives.py, lines 367 to 375 in 53d2486:

shape_for_reduction, reduction_dims = _get_reduction_params(
    block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
    shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)
is there a reason why it's assumed to be 2d here?
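For illustration, a self-contained sketch (independent of torchao's _get_reduction_params; the helper name and shapes are mine) of how an N-d block_size maps to one scale per block via the reshape-and-reduce pattern quoted above:

import torch

def per_block_absmax_scales(x: torch.Tensor, block_size: tuple) -> torch.Tensor:
    # split every dim into (num_blocks, block) and reduce over the block axes,
    # mirroring the shape_for_reduction / reduction_dims logic quoted above
    assert x.dim() == len(block_size)
    shape_for_reduction, reduction_dims = [], []
    for i, (dim, blk) in enumerate(zip(x.shape, block_size)):
        assert dim % blk == 0
        shape_for_reduction += [dim // blk, blk]
        reduction_dims.append(2 * i + 1)
    return x.reshape(shape_for_reduction).abs().amax(dim=tuple(reduction_dims), keepdim=True)

x = torch.randn(3, 3, 10, 10)
print(per_block_absmax_scales(x, (3, 3, 10, 2)).squeeze().shape)   # torch.Size([5]): per_group, groupsize=2
print(per_block_absmax_scales(x, (1, 3, 10, 10)).squeeze().shape)  # torch.Size([3]): per_axis, axis=0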
I guess it's fine to start with 2d for now as well, but it would be good to add an assert and create an issue for further development.
I made it 2d because that is how it was implemented in the AQLM code. The code will be a little more complicated, but it should be possible to generalize to more than 2d.
Benchmarks were run on a single NVIDIA A6000 GPU.
Seems slow; might need custom kernels.
else:
    codes = self.codes
    if codes.dtype == torch.uint8:
        codes = codes.to(torch.int32)  # not sure how to index with uint8
Is this needed? This might be the reason why it's slow. What do you mean by index?
I think so. I do indexing via dequant = codebook[codes] in dequantize_codebook. I got an error when I tried doing codebook[codes] when codes was uint8.
What is the error? It might be easy to support this in PyTorch, I feel.
Only int and long index dtypes are supported right now: https://github.com/pytorch/pytorch/blob/6e203ae6deaceb370e497bd50f2d02e894f5e9cc/aten/src/ATen/Dispatch.h#L801
UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
IndexError: The shape of the mask [2] at index 0 does not match the shape of the indexed tensor [3, 3] at index 0.
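For reference, a minimal repro of the uint8-indexing issue and the cast workaround (toy shapes, not the PR's actual tensors):

import torch

codebook = torch.randn(16, 4)                        # 16 centroids of dimension 4
codes = torch.randint(0, 16, (8,), dtype=torch.uint8)

# codebook[codes] with uint8 codes is interpreted as a (deprecated) boolean mask,
# which produces the warning and IndexError quoted above.
# Casting to a supported integer index dtype turns it into a gather:
dequant = codebook[codes.to(torch.int64)]            # shape (8, 4)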
OK, let's add a TODO here to follow up and land this for now. Is the perplexity number expected here? It looks like it's slightly worse than int4wo: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#cuda-backend
I think the perplexity is expected to be slightly worse than asymmetric int4wo (with scales and zero points) because I only add scales. But it should usually be better than symmetric int4wo (scales only, no zero points), though I initialize the centroids in k-means randomly, so it could be worse if the centroids have a bad initialization.
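As a side note on initialization, a common mitigation for bad random seeds is k-means++-style greedy seeding; here is a sketch (not necessarily identical to the greedy init added later in this PR):

import torch

def kmeanspp_init(points: torch.Tensor, k: int) -> torch.Tensor:
    # pick each new centroid with probability proportional to its squared
    # distance from the closest centroid chosen so far
    centroids = [points[torch.randint(points.shape[0], (1,))].squeeze(0)]
    for _ in range(k - 1):
        d2 = torch.cdist(points, torch.stack(centroids)).pow(2).amin(dim=1)
        idx = torch.multinomial(d2 / d2.sum(), 1)
        centroids.append(points[idx].squeeze(0))
    return torch.stack(centroids)

pts = torch.randn(1024, 8)
init_centroids = kmeanspp_init(pts, k=16)   # (16, 8) initial centroids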
1 - Fix an extraneous skip end that is out of order with a skip begin.
2 - Fix some typos.
PS: This might cause some README tests to fail, as they have not been run in a long time.
I think a model-level benchmark is fine for now; maybe add some unit tests for basic functionality like getting the codebook, quantize, and dequantize before landing?
Also, at the model level, is it possible to repro some accuracy results from AQLM: https://github.com/Vahe1994/AQLM/tree/main ?
I tested whether I could convert an AQLM quantized model to my implementation as a sanity check, and it seems to work: https://gist.github.com/DerekLiu35/c1cb9594c515e92c64762cbc8d087f7a.
Thanks! So the representation can be verified, but we are not sure how to repro the accuracy; in the original AQLM they'd need full-model finetuning to get good accuracy. That seems like a good next step, if you are interested in improving this further.
dequant = cqt.dequantize()

torch.testing.assert_close(dequant, self.input, atol=2, rtol=2)
Looks like a large drop. We could use a larger dtype, e.g. torch.uint8, for the test I think, so we can get something closer.
mse = torch.mean((dequant - self.input) ** 2).item()
self.assertLess(mse, 0.01)
We typically use SQNR:
ao/torchao/quantization/utils.py, line 50 in 46b8796:
def compute_error(x, y):
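For context, an SQNR-style metric in the same spirit (the definition of compute_error in torchao/quantization/utils.py is the source of truth; this is just a sketch):

import torch

def sqnr_db(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB; higher means y is closer to x
    return 20 * torch.log10(torch.linalg.norm(x) / torch.linalg.norm(x - y))

# e.g. in the test, instead of an MSE threshold:
# self.assertGreater(sqnr_db(self.input, dequant).item(), 30)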
Yeah, I'd definitely be interested in trying to implement AQLM to improve accuracy. |
dequant = cqt.dequantize()

torch.testing.assert_close(dequant, self.input, atol=0.1, rtol=0.1)
nit: it's fine to just rely on sqnr for error checking btw.
Thanks for the contribution and addressing all the comments!
* Add codebook_ops
* Add codebook_quantized_tensor
* Add __init__.py
* Fix uint8 indexing
* Add codebook_weight_only
* add codebook to eval and generate
* Make scales max of scale group if block_size = (1, 1)
* generalize block_size to more than 2d
* add codebook section to README
* add greedy init to means
* change codes casting condition
* Update __init__.py
* Add tests
* add TODO
* make multiplication inplace
* store codebook and scales in input_tensor.dtype instead of float32
* update tests
* remove torch.allclose check
@implements_torch_function(torch.Tensor.detach)
def function_detach(tensor, *args, **kwargs):
    return tensor.detach()
To overload linear:

@implements_torch_function(nn.functional.linear)
def function_linear(tensor, *args, **kwargs):
    breakpoint()
    # torch.ops.torchao.tinygemm(args)
    return tensor.detach()
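A slightly fuller sketch of a fallback linear overload, assuming the subclass exposes dequantize() as in the tests above and that implements_torch_function (the decorator from this PR) forwards the original (input, weight, bias) arguments; the fused-kernel call is left as a placeholder:

import torch.nn.functional as F

@implements_torch_function(F.linear)
def function_linear(input, weight, bias=None):
    # naive fallback: dequantize the codebook-quantized weight and run a regular
    # linear; a fused kernel (e.g. a tinygemm-style op) could be dispatched here later
    return F.linear(input, weight.dequantize(), bias)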
This PR adds a codebook quantization flow per #1195.
Usage
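A hedged sketch of the intended call shape, based on the codebook_weight_only API added in this PR (the import path and argument names below are assumptions; see the PR's __init__.py and the generate.py changes for the actual ones):

import torch
from torchao.quantization import quantize_
from torchao.prototype.quantization.codebook import codebook_weight_only  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16)
# quantize linear weights with a learned codebook; dtype controls the codebook size
quantize_(model, codebook_weight_only(dtype=torch.uint4))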
ToDo
Make fit_kmeans faster. Right now it takes >1 hour if you try to quantize a 1B model.
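One possible direction for the speedup (a sketch, not profiled against this PR's fit_kmeans): do the assignment step with a matmul-based distance expansion, chunked over points so the full (n, k) distance matrix is never materialized at once:

import torch

@torch.no_grad()
def assign_codes(points: torch.Tensor, centroids: torch.Tensor, chunk: int = 65536) -> torch.Tensor:
    # nearest-centroid assignment using ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2;
    # ||x||^2 is constant per row, so it can be dropped from the argmin
    c_sq = centroids.pow(2).sum(dim=1)                 # (k,)
    codes = torch.empty(points.shape[0], dtype=torch.int64, device=points.device)
    for start in range(0, points.shape[0], chunk):
        x = points[start:start + chunk]
        d = c_sq - 2.0 * (x @ centroids.T)             # (chunk, k), up to a per-row constant
        codes[start:start + chunk] = d.argmin(dim=1)
    return codes

pts = torch.randn(100_000, 8)
cents = torch.randn(256, 8)
codes = assign_codes(pts, cents)   # (100_000,) int64 code indices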