diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 7c8f53be2c..dede4e9707 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -193,7 +193,8 @@ def test_int8_mixed_precision_training(self, compile, config, module_swap): linear = nn.Linear(embed_dim, embed_dim, device=device) linear_int8mp = copy.deepcopy(linear) - apply_func = int8_mixed_precision_training(config, module_swap=module_swap) + config.module_swap = module_swap + apply_func = int8_mixed_precision_training(config) quantize_(linear_int8mp, apply_func, set_inductor_config=False) if compile: diff --git a/torchao/prototype/quantized_training/bitnet.py b/torchao/prototype/quantized_training/bitnet.py index 10c030ded1..5d1a0a539f 100644 --- a/torchao/prototype/quantized_training/bitnet.py +++ b/torchao/prototype/quantized_training/bitnet.py @@ -12,7 +12,10 @@ from torch.distributed._tensor import DTensor from torch.utils._triton import has_triton -from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.core.config import AOBaseConfig +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.utils import TorchAOBaseTensor from .int8 import quantize_int8_rowwise @@ -232,10 +235,22 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias -def bitnet_training(): - return _get_linear_subclass_inserter( - BitNetTrainingLinearWeight, allow_requires_grad=True - ) +class BitNetTrainingConfig(AOBaseConfig): + pass + + +# for bc +bitnet_training = BitNetTrainingConfig + + +@register_quantize_module_handler(BitNetTrainingConfig) +def _bitnet_training_transform( + module: torch.nn.Module, + config: BitNetTrainingConfig, +) -> torch.nn.Module: + new_weight = BitNetTrainingLinearWeight(module.weight) + module.weight = torch.nn.Parameter(new_weight, requires_grad=True) + return module def _pack_i2_in_i8(x: Tensor): diff --git a/torchao/prototype/quantized_training/int8.py b/torchao/prototype/quantized_training/int8.py index 94c5043da3..fe6415de11 100644 --- a/torchao/prototype/quantized_training/int8.py +++ b/torchao/prototype/quantized_training/int8.py @@ -4,7 +4,10 @@ from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.core.config import AOBaseConfig +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten @@ -293,7 +296,19 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, out) -def int8_weight_only_quantized_training(): - return _get_linear_subclass_inserter( - Int8QuantizedTrainingLinearWeight.from_float, allow_requires_grad=True - ) +class Int8WeightOnlyQuantizedTrainingConfig(AOBaseConfig): + pass + + +# for bc +int8_weight_only_quantized_training = Int8WeightOnlyQuantizedTrainingConfig + + +@register_quantize_module_handler(Int8WeightOnlyQuantizedTrainingConfig) +def _int8_weight_only_quantized_training_transform( + module: torch.nn.Module, + config: Int8WeightOnlyQuantizedTrainingConfig, +) -> torch.nn.Module: + new_weight = Int8QuantizedTrainingLinearWeight.from_float(module.weight) + module.weight = torch.nn.Parameter(new_weight, requires_grad=True) + return module diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index 3e7b20a11b..9be21af120 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -1,11 +1,15 @@ -from typing import Any, NamedTuple, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union import torch import torch.utils._pytree as pytree from torch import Tensor, nn from torch.utils._triton import has_triton -from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.core.config import AOBaseConfig +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.utils import TorchAOBaseTensor from .int8 import quantize_int8_rowwise @@ -23,10 +27,16 @@ def scaled_int8_mm( return torch._int_mm(A, B) * col_scale.view(-1) * row_scale.view(-1, 1) -class Int8MixedPrecisionTrainingConfig(NamedTuple): +@dataclass +class Int8MixedPrecisionTrainingConfig(AOBaseConfig): output: bool = True grad_input: bool = True grad_weight: bool = True + module_swap: bool = False + + +# for bc +int8_mixed_precision_training = Int8MixedPrecisionTrainingConfig _DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig() @@ -265,25 +275,23 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None -def int8_mixed_precision_training( - config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG, - *, - module_swap: bool = False, +@register_quantize_module_handler(Int8MixedPrecisionTrainingConfig) +def _int8_mixed_precision_training_transform( + module: torch.nn.Module, + config: Int8MixedPrecisionTrainingConfig, ): + module_swap = config.module_swap + # TODO: skip small layers that don't have perf gain. if module_swap: # module swap implementation - def convert_linear(linear: nn.Linear): - linear.__class__ = Int8MixedPrecisionTrainingLinear - linear.config = config - return linear - - return convert_linear + module.__class__ = Int8MixedPrecisionTrainingLinear + module.config = config + return module else: # tensor subclass implementation - return _get_linear_subclass_inserter( - Int8MixedPrecisionTrainingLinearWeight, - config=config, - allow_requires_grad=True, - ) + + new_weight = Int8MixedPrecisionTrainingLinearWeight(module.weight, config) + module.weight = torch.nn.Parameter(new_weight, requires_grad=True) + return module