Description
Previously we did this:

```python
import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

model = torch.compile(model, mode="max-autotune", fullgraph=True)
change_linear_weights_to_int8_woqtensors(model)
# compiling after quantization also works
```
With the new quantization API, we have to do this:

```python
import torch
from torchao.quantization.quant_api import quantize, int8wo, unwrap_tensor_subclass

model = quantize(model, int8wo())  # or "int8_weight_only"
model = unwrap_tensor_subclass(model)
model = torch.compile(model, mode="max-autotune", fullgraph=True)  # must compile after unwrap
```
I think the new API is less user-friendly than the previous one:

- `int8wo()`, `int4wo()` is a bit unintuitive. I understand it is a mechanism to pass params like group size to the quantization. Alternatives: a full-blown class with a `__call__()` method, e.g. `Int8WeightOnlyConfig` (kinda verbose, but the intention is clear), or just pass quant params as extra args/kwargs, e.g. `quantize("int4wo", groupsize=128)` (see the sketch after this list).
- It's not clear what `unwrap_tensor_subclass()` does. Also, why do we need it now to compile the model, but not previously?
- Small doc correction: `unwrap_tensor_subclass()` should be imported from `torchao.utils` or `torchao.quantization.quant_api`, not `torchao.quantization.utils` (https://github.com/pytorch/ao/tree/main/torchao/quantization).
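To make the config-class suggestion concrete, here is a minimal, self-contained sketch. `Int8WeightOnlyConfig` is just the name floated above; everything else (the placeholder transform body, the simplified `quantize`) is a hypothetical stand-in for illustration, not torchao's actual implementation.

```python
from dataclasses import dataclass
from typing import Callable

import torch.nn as nn


@dataclass
class Int8WeightOnlyConfig:
    """Hypothetical config object: quant params are explicit, named fields,
    and __call__ lets an instance be used wherever quantize() currently
    expects a callable such as int8wo()."""

    group_size: int = 128

    def __call__(self, module: nn.Module) -> nn.Module:
        # Placeholder for the real weight-only int8 transform; torchao
        # would swap the weight for its quantized tensor subclass here.
        print(f"quantizing {type(module).__name__} with group_size={self.group_size}")
        return module


def quantize(model: nn.Module, config: Callable[[nn.Module], nn.Module]) -> nn.Module:
    # Simplified stand-in for torchao's quantize(): apply the config
    # callable to every nn.Linear in the model, recursing into children.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, config(child))
        else:
            quantize(child, config)
    return model


model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
model = quantize(model, Int8WeightOnlyConfig(group_size=64))
```

Compared with `int8wo()`, the call site reads the same, but the available parameters are discoverable from the class definition, and the kwargs alternative (`quantize("int4wo", groupsize=128)`) could be layered on top by mapping the string to a default config.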