
Commit 6b43a1c

Update quantization to use tensor subclasses (#1403)
1 parent b0895a7 commit 6b43a1c

2 files changed (+54 −35 lines)

recipes/eleuther_eval.py

Lines changed: 6 additions & 2 deletions
@@ -222,11 +222,15 @@ def _setup_model(
     ) -> nn.Module:
         with training.set_default_dtype(self._dtype), self._device:
             model = config.instantiate(model_cfg)
+
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-
-        model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)
 
         # Put model in eval mode.
         # Note: This will not disable the dropout applied in SDPA,
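The new branch is needed because a quantized model's weights are now tensor subclasses: the recipe moves each checkpoint tensor to the target device and loads with assign=True, which installs the checkpoint tensors directly as the module's parameters instead of copying values into the existing ones. A minimal standalone sketch of the pattern (the checkpoint path "ckpt.pt" and the plain nn.Linear are hypothetical stand-ins):

import torch
import torch.nn as nn

model = nn.Linear(8, 8, device="cuda")
state_dict = torch.load("ckpt.pt", map_location="cpu")
# assign=True reuses these tensors as the module's parameters rather
# than copying into them, so they must already be on the target device.
for k, v in state_dict.items():
    state_dict[k] = v.to("cuda")
model.load_state_dict(state_dict, assign=True)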

torchtune/training/quantization.py

Lines changed: 48 additions & 33 deletions
@@ -6,8 +6,23 @@
 
 from typing import Callable, Optional
 
+from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.quantization.prototype.qat import (
+    disable_8da4w_fake_quant,
+    enable_8da4w_fake_quant,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.prototype.qat._module_swap_api import (
+    disable_8da4w_fake_quant_module_swap,
+    enable_8da4w_fake_quant_module_swap,
+    Int8DynActInt4WeightQATQuantizerModuleSwap,
+)
+
+
 __all__ = [
     "get_quantizer_mode",
+    "Int8DynActInt4WeightQuantizer",
+    "Int8DynActInt4WeightQATQuantizer",
 ]
 
 
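These imports consolidate the scattered per-quantizer imports removed below: post-training quantization now goes through torchao's quantize_ entry point, which applies a quantization recipe to a model in place. A rough sketch of that API as the commit uses it (the toy model is an assumption):

import torch.nn as nn
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 16))
# quantize_ mutates the model in place, replacing each nn.Linear weight
# with a tensor subclass carrying int4 group-quantized data; int8
# activation quantization happens dynamically at runtime.
quantize_(model, int8_dynamic_activation_int4_weight(64))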
@@ -16,47 +31,47 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
-from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+# ========================================================
+# int8 dynamic activations + int4 weight tensor subclass |
+# ========================================================
 
-__all__.append("Int8DynActInt4WeightQuantizer")
-_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 
+class Int8DynActInt4WeightQuantizer:
+    """
+    Quantizer for applying int8 per token dynamic activation + int4
+    per group weight quantization to linear layers in the model.
+    """
+
+    def __init__(self, groupsize: int = 256):
+        self.groupsize = groupsize
+
+    def quantize(self, model):
+        quantize_fn = int8_dynamic_activation_int4_weight(self.groupsize)
+        quantize_(model, quantize_fn)
+        return model
 
-from torchao.quantization.prototype.qat import (
-    disable_8da4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int8DynActInt4WeightQATQuantizer,
-)
 
-__all__.append("Int8DynActInt4WeightQATQuantizer")
+_quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizer] = "8da4w-qat"
 _quantizer_mode_to_disable_fake_quant["8da4w-qat"] = disable_8da4w_fake_quant
 _quantizer_mode_to_enable_fake_quant["8da4w-qat"] = enable_8da4w_fake_quant
 
-try:
-    # Note: QAT tensor subclass implementation in torchao only works
-    # with FSDP2 today. For other distribution strategies like DDP and
-    # FSDP1, users will need to fall back to the old module swap flow.
-    # TODO: remove this try catch once we upgrade to torchao 0.5.0
-
-    from torchao.quantization.prototype.qat._module_swap_api import (
-        disable_8da4w_fake_quant_module_swap,
-        enable_8da4w_fake_quant_module_swap,
-        Int8DynActInt4WeightQATQuantizerModuleSwap,
-    )
-
-    __all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
-    _quantizer_to_mode[
-        Int8DynActInt4WeightQATQuantizerModuleSwap
-    ] = "8da4w-qat-module-swap"
-    _quantizer_mode_to_disable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = disable_8da4w_fake_quant_module_swap
-    _quantizer_mode_to_enable_fake_quant[
-        "8da4w-qat-module-swap"
-    ] = enable_8da4w_fake_quant_module_swap
-except ImportError:
-    pass
+
+# ====================================================
+# int8 dynamic activations + int4 weight module swap |
+# ====================================================
+
+# Note: QAT tensor subclass implementation in torchao only works
+# with FSDP2 today. For other distribution strategies like DDP and
+# FSDP1, users will need to fall back to the old module swap flow.
+__all__.append("Int8DynActInt4WeightQATQuantizerModuleSwap")
+_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
+_quantizer_mode_to_disable_fake_quant[
+    "8da4w-qat-module-swap"
+] = disable_8da4w_fake_quant_module_swap
+_quantizer_mode_to_enable_fake_quant[
+    "8da4w-qat-module-swap"
+] = enable_8da4w_fake_quant_module_swap
 
 
 def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
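get_quantizer_mode and the three registration dicts give recipes a single mode string per quantization flow; a recipe can resolve a configured quantizer to its mode and use it to toggle fake quantization during QAT. A hedged sketch of that intended usage (the model and the delayed-fake-quant schedule are assumptions; torchao applies these toggles per-module via Module.apply):

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
mode = get_quantizer_mode(quantizer)  # "8da4w-qat"

# Run the first few steps without fake quantization, then enable it.
model.apply(_quantizer_mode_to_disable_fake_quant[mode])
# ... warmup training steps ...
model.apply(_quantizer_mode_to_enable_fake_quant[mode])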

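Per the note in the diff, the tensor-subclass QAT flow composes only with FSDP2, so DDP and FSDP1 users fall back to the module-swap variant; its distinct mode string keeps the two flows separate in recipe configs. A sketch of that fallback (assuming the quantizer exposes torchao's usual prepare() hook; groupsize chosen arbitrarily):

quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=256)
assert get_quantizer_mode(quantizer) == "8da4w-qat-module-swap"
# prepare() swaps nn.Linear modules for fake-quantized equivalents
# instead of wrapping their weights in tensor subclasses.
model = quantizer.prepare(model)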