Feedback on quantize() API #384

Closed
Description

@gau-nernst

Previously, we did this:

from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

model = torch.compile(model, mode="max-autotune", fullgraph=True)
change_linear_weights_to_int8_woqtensors(model)
# compile after quant also works

With the new quantization API, we have to do this:

from torchao.quantization.quant_api import quantize, int8wo, unwrap_tensor_subclass

model = quantize(model, int8wo())  # or "int8_weight_only"
model = unwrap_tensor_subclass(model)
model = torch.compile(model, mode='max-autotune', fullgraph=True)  # must compile after unwrap

I think the new API is less user-friendly than the previous one:

  1. int8wo() and int4wo() are a bit unintuitive. I understand they are a mechanism for passing params like group size to the quantization. Alternatives: a full-blown class with a __call__() method, e.g. Int8WeightOnlyConfig (kinda verbose, but the intention is clear); or just pass quant params as extra args/kwargs, e.g. quantize("int4wo", groupsize=128). See the sketch below this list.
  2. It's not clear what unwrap_tensor_subclass() does. Also, why do we need it now to compile the model, but not previously?
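To make the two alternatives in point 1 concrete, here is a minimal sketch. Int8WeightOnlyConfig, the registry, and the string-plus-kwargs quantize() signature are hypothetical illustrations of the proposal, not existing torchao APIs:

import torch

# Alternative A (hypothetical): a config class whose __call__() applies the
# quantization, so the method and its params are explicit at the call site.
class Int8WeightOnlyConfig:
    def __init__(self, group_size=None):
        self.group_size = group_size

    def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
        # Placeholder: a real version would swap each linear weight for an
        # int8 weight-only tensor subclass here.
        return model

# Alternative B (hypothetical): quantize() takes a method name plus that
# method's params as plain kwargs and forwards them to a registered config.
def quantize(model, method, **kwargs):
    registry = {"int8wo": Int8WeightOnlyConfig}
    return registry[method](**kwargs)(model)

model = torch.nn.Linear(16, 16)
model = Int8WeightOnlyConfig(group_size=128)(model)  # alternative A
model = quantize(model, "int8wo", group_size=128)    # alternative B

Either shape keeps the params discoverable: the config class makes them attributes with visible defaults, while the kwargs form keeps the call short but pushes param validation into the registry lookup.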
