[WIP] Codebook quantization flow #1299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 23 commits · Dec 17, 2024
Conversation

DerekLiu35 (Contributor):

This PR adds codebook quantization flow per #1195

Usage

```python
import torch
from torchao.prototype.quantization.codebook.codebook_quantized_tensor import CodebookQuantizedTensor

input_tensor = torch.randn(1024, 1024, device='cuda')

block_size = (1, 1)
code_dtype = torch.uint4

quantized_tensor = CodebookQuantizedTensor.from_float(input_tensor, block_size, code_dtype)

dequantized_tensor = quantized_tensor.dequantize()
```

ToDo

  • Make fit_kmeans faster: right now it takes over an hour to quantize a 1B model.

pytorch-bot bot commented Nov 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1299

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit b938a7c with merge base 46b8796:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 16, 2024
@jerryzh168 (Contributor):

Thanks for the contribution! Yeah, ">1 hour" seems a bit too slow; any ideas for speeding it up?

@jerryzh168 (Contributor):

Also, after this is done, it would be useful if you could add codebook quant to generate.py (the `if quantization:` branch) and eval (the `if quantization:` branch) to test the e2e model performance and accuracy.

@DerekLiu35 (Contributor, Author):

> Thanks for the contribution! Yeah, ">1 hour" seems a bit too slow; any ideas for speeding it up?

I think:

  • For block_size = (1, 1), it's similar to NF4Tensor, so we can use absolute distance for scalars instead of Euclidean distance.
  • We could also decrease max_iter from 1000 to 200 for fit_kmeans, but this would increase quantization error.
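For illustration, a minimal sketch of the scalar (1-D) k-means idea from the first bullet, using absolute distance for nearest-centroid assignment (`fit_kmeans_scalar` is a hypothetical helper, not the PR's `fit_kmeans`):

```python
import torch

def fit_kmeans_scalar(x, k, iters=20):
    # 1-D k-means: for block_size=(1, 1) each codebook entry is a scalar,
    # so nearest-centroid search can use |x - c| directly instead of a
    # full Euclidean distance computation over vectors.
    flat = x.reshape(-1).float()
    # initialize centroids from quantiles for a stable start
    centroids = torch.quantile(flat, torch.linspace(0, 1, k))
    for _ in range(iters):
        codes = (flat.unsqueeze(1) - centroids.unsqueeze(0)).abs().argmin(dim=1)
        for j in range(k):
            members = flat[codes == j]
            if members.numel() > 0:
                centroids[j] = members.mean()
    return centroids, codes
```

In 1-D, |x - c| and (x - c)² pick the same nearest centroid, so the assignment is exact while avoiding the squared-distance matrix.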

DerekLiu35 and others added 4 commits November 30, 2024 20:45

DerekLiu35 commented Dec 1, 2024

Changed max_iter from 1000 to 200 and added codebook to eval and generate.
I also added scales with group size 64, since eval with torch.uint4 was getting high perplexity; the perplexity is still ~200 even with the group-size-64 scales.


jerryzh168 commented Dec 2, 2024

> Changed max_iter from 1000 to 200 and added codebook to eval and generate. I also added scales with group size 64, since eval with torch.uint4 was getting high perplexity; the perplexity is still ~200 even with the group-size-64 scales.

Thanks, why is the perplexity so high? We get around 12/13 when using int4wo-64 on llama2: https://github.com/pytorch/ao/tree/main/torchao/quantization#cuda-backend

DerekLiu35 and others added 2 commits December 4, 2024 07:28

DerekLiu35 commented Dec 4, 2024

> Thanks, why is the perplexity so high? We get around 12/13 when using int4wo-64 on llama2: https://github.com/pytorch/ao/tree/main/torchao/quantization#cuda-backend

I found that it was because I was setting the scales to the norm of each scale group instead of the max. If you set the scales to the max of each scale group, wikitext perplexity is ~11.6 for llama-3.2-3B.
AQLM does initialize its scales as the norm of each scale group rather than the max, but I guess they only do vector quantization and no scalar quantization.
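A toy sketch of the two scale initializations being compared (`group_scales` is a hypothetical helper; the group size and shapes are illustrative):

```python
import torch

def group_scales(w, group_size=64, mode="max"):
    # Per-group scales over the last dimension:
    #   mode="max":  scale = max(|w|) per group (what fixed the perplexity here)
    #   mode="norm": scale = ||w||  per group (AQLM-style initialization)
    groups = w.reshape(w.shape[0], -1, group_size)
    if mode == "max":
        return groups.abs().amax(dim=-1, keepdim=True)
    return groups.norm(dim=-1, keepdim=True)

w = torch.randn(4, 128)
scaled = w.reshape(4, -1, 64) / group_scales(w, 64, "max")
# max-scaling bounds every scaled value to [-1, 1], which suits a scalar
# codebook; norm-scaling does not give that bound.
```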

@jerryzh168 (Contributor):

Thanks! If both performance and accuracy are reasonable, I think the main thing left is to add a section for codebook quant here: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#uintx-quantization and show the perplexity and tokens/s results from eval.py and generate.py.

```python
else:
    codebook_size = _DTYPE_TO_QVALUE_BOUNDS[code_dtype][1] + 1

out_block_size, in_block_size = block_size
```
Contributor:

block_size is a general arg that allows people to do all kinds of granularities.

Note: how can block_size represent different granularities? Say we have a tensor of size (3, 3, 10, 10); here is a table showing how block_size maps to each granularity:

| granularity type | block_size |
|---|---|
| per_tensor | (3, 3, 10, 10) |
| per_axis (axis=0) | (1, 3, 10, 10) |
| per_axis (axis=1) | (3, 1, 10, 10) |
| per_group (groupsize=2) | (3, 3, 10, 2) |
| per_group (groupsize=2) for axis=3 | (3, 3, 2, 10) |

You can also refer to

```python
shape_for_reduction, reduction_dims = _get_reduction_params(
    block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
    shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)
```

for some helper functions that help make the shapes correct.

is there a reason why it's assumed to be 2d here?
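For illustration, the granularity table above can be exercised with a small block-wise reduction sketch (`blockwise_absmax` is a hypothetical helper, not torchao's `_get_reduction_params`):

```python
import torch

def blockwise_absmax(x, block_size):
    # Split each dim into (num_blocks, block) and reduce over the block halves,
    # mirroring how a generic block_size encodes granularity.
    shape, reduce_dims = [], []
    for i, (dim, blk) in enumerate(zip(x.shape, block_size)):
        assert dim % blk == 0
        shape += [dim // blk, blk]
        reduce_dims.append(2 * i + 1)
    return x.view(shape).abs().amax(dim=reduce_dims)

x = torch.randn(3, 3, 10, 10)
per_tensor = blockwise_absmax(x, (3, 3, 10, 10))  # one value for the tensor
per_axis0 = blockwise_absmax(x, (1, 3, 10, 10))   # one value per slice of dim 0
per_group = blockwise_absmax(x, (3, 3, 10, 2))    # groupsize-2 along dim 3
```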

Contributor:

I guess it's fine to start with 2d for now, but it would be good to add an assert and create an issue for further development.

Contributor (Author):

I made it 2d because that is how it was implemented in AQLM code. The code will be a little more complicated, but it should be possible to generalize to more than 2d.

DerekLiu35 and others added 3 commits December 6, 2024 16:43
@DerekLiu35 (Contributor, Author):

Benchmarks were run on a single NVIDIA A6000 GPU.

| Model | Technique | wikitext perplexity | Tokens/s | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
|---|---|---|---|---|---|---|
| Llama-3-8B | Base (bfloat16) | 7.590 | 32.36 | 485.71 | 16.19 | 15.01 |
| Llama-3-8B | codebook-4-64 | 9.533 | 1.73 | 8.62 | 23.11 | 4.98 |
| Llama-3.1-8B | Base (bfloat16) | 7.713 | 32.16 | 482.70 | 16.35 | 15.01 |
| Llama-3.1-8B | codebook-4-64 | 10.095 | 1.73 | 8.63 | 23.11 | 4.98 |

Seems slow, might need custom kernels

```python
else:
    codes = self.codes
    if codes.dtype == torch.uint8:
        codes = codes.to(torch.int32)  # not sure how to index with uint8
```
@jerryzh168 (Contributor) commented Dec 6, 2024:
Is this needed? This might be the reason why it's slow. What do you mean by index?

@DerekLiu35 (Contributor, Author) commented Dec 6, 2024:

I think so. I do indexing via `dequant = codebook[codes]` in `dequantize_codebook`, and I got an error when I tried `codebook[codes]` while codes was uint8.
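A minimal repro of the issue and the cast workaround (shapes are illustrative; the PR casts to int32, torch.long shown here):

```python
import torch

codebook = torch.randn(16, 1)
codes = torch.randint(0, 16, (8,), dtype=torch.uint8)

# Indexing with a uint8 tensor is deprecated: it is interpreted as a boolean
# mask, which is what produces the IndexError quoted later in the thread.
# Casting to an integer index dtype restores plain gather-style indexing.
dequant = codebook[codes.to(torch.long)]
```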

Contributor:

what is the error? might be easy to support this in pytorch I feel

Contributor (Author):

```
UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
IndexError: The shape of the mask [2] at index 0 does not match the shape of the indexed tensor [3, 3] at index 0.
```

Contributor:

OK, let's add a TODO here to follow up, and land this for now. Is the perplexity number expected here? It looks slightly worse than int4wo: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#cuda-backend

Contributor (Author):

I think the perplexity is expected to be slightly worse than asymmetric int4wo (with scales and zero points) because I only add scales. It should usually be better than symmetric int4wo (scales only, no zero points), though I initialize the k-means centroids randomly, so it could be worse if the centroids get a bad initialization.

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
@jerryzh168 (Contributor):

I think a model-level benchmark is fine for now; maybe add some unit tests for basic functionality (getting the codebook, quantize, dequantize) before landing?

@jerryzh168 (Contributor):

Also, at the model level, is it possible to repro some accuracy results from AQLM: https://github.com/Vahe1994/AQLM/tree/main

DerekLiu35 and others added 2 commits December 13, 2024 18:09
@DerekLiu35 (Contributor, Author):

As a sanity check, I tested whether I could convert an AQLM-quantized model to my implementation, and it seems to work: https://gist.github.com/DerekLiu35/c1cb9594c515e92c64762cbc8d087f7a.
Though, I get ~300 perplexity for llama-3-8b when I use settings similar to AQLM's (block_size=(1, 8), code_dtype=torch.int32 (which makes codebook_size = 2**16), and scale_block_size=input.shape[1]). Maybe AQLM's extra fine-tuning is necessary for lower perplexity?

@jerryzh168 (Contributor):

Thanks! So the representation can be verified, but we are not sure how to repro the accuracy, and the original AQLM needs full-model finetuning to get good accuracy. Seems like a good next step, if you are interested in improving this further.


```python
dequant = cqt.dequantize()

torch.testing.assert_close(dequant, self.input, atol=2, rtol=2)
```
@jerryzh168 (Contributor) commented Dec 13, 2024:
That's a large tolerance; we could use a larger dtype, e.g. torch.uint8, for the test, so we can get something closer.

Comment on lines 65 to 66

```python
mse = torch.mean((dequant - self.input) ** 2).item()
self.assertLess(mse, 0.01)
```
Contributor:

We typically use SQNR (see `compute_error(x, y)` in torchao's test utils); numbers greater than 20 or 30 are reasonable.
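For reference, the SQNR check looks roughly like this (a sketch of `compute_error`; the exact helper in torchao may differ in detail):

```python
import torch

def compute_error(x, y):
    # Signal-to-quantization-noise ratio in dB: 20 * log10(||x|| / ||x - y||).
    # Higher is better; 20-30+ dB is a reasonable bar for a quantization
    # round trip.
    return 20 * torch.log10(torch.linalg.norm(x) / torch.linalg.norm(x - y))

x = torch.randn(1024)
reconstructed = x + 0.001 * torch.randn(1024)
sqnr = compute_error(x, reconstructed)
```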

@DerekLiu35 (Contributor, Author):

> Thanks! So the representation can be verified, but we are not sure how to repro the accuracy, and the original AQLM needs full-model finetuning to get good accuracy. Seems like a good next step, if you are interested in improving this further.

Yeah, I'd definitely be interested in trying to implement AQLM to improve accuracy.


```python
dequant = cqt.dequantize()

torch.testing.assert_close(dequant, self.input, atol=0.1, rtol=0.1)
```
Contributor:

nit: it's fine to just rely on sqnr for error checking btw.

@jerryzh168 (Contributor) left a comment:

Thanks for the contribution and addressing all the comments!

@jerryzh168 jerryzh168 added topic: not user facing Use this tag if you don't want this PR to show up in release notes topic: new feature Use this tag if this PR adds a new feature labels Dec 14, 2024
@jerryzh168 jerryzh168 merged commit bc000aa into pytorch:main Dec 17, 2024
18 of 20 checks passed
amdfaa pushed a commit that referenced this pull request Jan 10, 2025
* Add codebook_ops

* Add codebook_quanized_tensor

* Add __init__.py

* Fix uint8 indexing

* Add codebook_weight_only

* add codebook to eval and generate

* Make scales max of scale group if block_size = (1, 1)

* generalize block_size to more than 2d

* add codebook section to README

* add greedy init to means

* change codes casting condition

* Update __init__.py

* Add tests

* add TODO

* make multiplication inplace

* store codebook and scales in input_tensor.dtype instead of float32

* update tests

* remove torch.allclose check
Comment on lines +247 to +249

```python
@implements_torch_function(torch.Tensor.detach)
def function_detach(tensor, *args, **kwargs):
    return tensor.detach()
```
Contributor:

@mostafaelhoushi

To overload linear:

```python
@implements_torch_function(nn.functional.linear)
def function_linear(tensor, *args, **kwargs):
    # dispatch to a custom kernel here, e.g. torch.ops.torchao.tinygemm(...)
    ...
```
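A self-contained sketch of the torch_function-override pattern this decorator implements (hypothetical registry; torchao's actual decorator differs in detail, and this assumes torch >= 2.0 for DisableTorchFunctionSubclass):

```python
import torch

TORCH_FN_TABLE = {}

def implements_torch_function(torch_fn):
    # Register an override for a torch function/method on the subclass below.
    def register(impl):
        TORCH_FN_TABLE[torch_fn] = impl
        return impl
    return register

class DemoQuantTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in TORCH_FN_TABLE:
            return TORCH_FN_TABLE[func](*args, **kwargs)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

@implements_torch_function(torch.Tensor.detach)
def _detach(tensor, *args, **kwargs):
    # Demo only: return the tensor unchanged; a real implementation would
    # detach the underlying quantized fields.
    return tensor

t = torch.randn(2).as_subclass(DemoQuantTensor)
```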
