diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py
index fa86850a07..c84a152330 100644
--- a/test/sparsity/test_supermask.py
+++ b/test/sparsity/test_supermask.py
@@ -6,8 +6,6 @@
 from torch import nn
 from torch.testing._internal import common_utils
 
-from torchao.sparsity import sparsify_
-
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
@@ -30,13 +28,10 @@ def test_supermask(self, sparsity_level, blocksize):
         from torchao.sparsity import SupermaskLinear
 
         M, N = model[0].weight.shape
-        sparsify_(
-            model,
-            lambda x: SupermaskLinear.from_linear(
-                x, sparsity_level=sparsity_level, blocksize=blocksize
-            ),
+        model[0] = SupermaskLinear.from_linear(
+            model[0], sparsity_level=sparsity_level, blocksize=blocksize
         )
-        sparsify_(model, SupermaskLinear.to_linear)
+        model[0] = SupermaskLinear.to_linear(model[0])
         weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize)
 
         # Test correct sparsity level
diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md
index b689a3adf4..4d894461ce 100644
--- a/torchao/sparsity/README.md
+++ b/torchao/sparsity/README.md
@@ -78,11 +78,11 @@ quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
 ### 2:4 sparsity
 
 ```py
-from torchao.sparsity.sparse_api import sparsify_, semi_sparse_weight
+from torchao.sparsity.sparse_api import sparsify_, SemiSparseWeightConfig
 from torchao.dtypes import SemiSparseLayout
 
 model = model.cuda()
-sparsify_(model, semi_sparse_weight())
+sparsify_(model, SemiSparseWeightConfig())
 ```
 
 ### Block sparsity
@@ -90,10 +90,10 @@ We offer prototype support for accelerating block sparsity with our triton kerne
 
 ```py
 from torchao.sparsity.sparse_api import sparsify_
-from torchao.sparsity import block_sparse_weight
+from torchao.sparsity import BlockSparseWeightConfig
 
 model = model.cuda()
-sparsify_(model, block_sparse_weight())
+sparsify_(model, BlockSparseWeightConfig())
 ```
 
 # Goal
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index 9e9611e0ad..4b5cf0d8d4 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -1,17 +1,23 @@
-from functools import partial
+import types
+from dataclasses import dataclass
 from typing import Callable, Optional
 
 import torch
 from torch.sparse import to_sparse_semi_structured
 
+from torchao.core.config import AOBaseConfig
 from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
     WeightNormSparsifier,
 )
 from torchao.quantization.quant_api import (
-    _get_linear_subclass_inserter,
     _is_linear,
+    _linear_extra_repr,
     _replace_with_custom_fn_if_matches_filter,
 )
+from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
+    register_quantize_module_handler,
+)
 from torchao.sparsity.blocksparse import BlockSparseTensor
 
 
@@ -35,22 +41,53 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.squash_mask()
 
 
-def block_sparse_weight(blocksize=64):
-    return _get_linear_subclass_inserter(
-        partial(BlockSparseTensor.from_dense, blocksize=blocksize)
-    )
+@dataclass
+class BlockSparseWeightConfig(AOBaseConfig):
+    blocksize: int = 64
+
+
+# for bc
+block_sparse_weight = BlockSparseWeightConfig
+
+
+@register_quantize_module_handler(BlockSparseWeightConfig)
+def _block_sparse_weight_transform(
+    module: torch.nn.Module,
+    config: BlockSparseWeightConfig,
+):
+    blocksize = config.blocksize
+    new_weight = BlockSparseTensor.from_dense(module.weight, blocksize)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
-def semi_sparse_weight():
+class SemiSparseWeightConfig(AOBaseConfig):
     """
-    Convert the weight of linear moduels to semi-structured (2:4) sparsity
+    Configuration for converting the weight of linear modules to semi-structured (2:4) sparsity
     """
-    return _get_linear_subclass_inserter(to_sparse_semi_structured)
+
+    pass
+
+
+# for bc
+semi_sparse_weight = SemiSparseWeightConfig
+
+
+@register_quantize_module_handler(SemiSparseWeightConfig)
+def _semi_sparse_weight_transform(
+    module: torch.nn.Module,
+    config: SemiSparseWeightConfig,
+) -> torch.nn.Module:
+    new_weight = to_sparse_semi_structured(module.weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def sparsify_(
     model: torch.nn.Module,
-    apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
+    config: AOBaseConfig,
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
 ) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`.
@@ -63,8 +100,8 @@ def sparsify_(
 
     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
+        config (AOBaseConfig): a workflow configuration object
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to apply the specified workflow to this module.
 
     **Example:**
     ::
@@ -85,8 +122,10 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         from torchao.dtypes import SemiSparseLayout
         m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
     """
+    handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
     _replace_with_custom_fn_if_matches_filter(
         model,
-        apply_tensor_subclass,
+        handler,
         _is_linear if filter_fn is None else filter_fn,
+        extra_args=(config,),
    )
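
For reviewers, a minimal sketch of the call pattern after this change, matching the README snippets above. The toy models and shapes are illustrative only; the `.half().cuda()` setup reflects the usual assumption that the 2:4 and triton block-sparse kernels expect fp16/bf16 weights on CUDA.

```py
import torch
from torch import nn

from torchao.sparsity import BlockSparseWeightConfig
from torchao.sparsity.sparse_api import SemiSparseWeightConfig, sparsify_

# Illustrative toy model; 2:4 sparsification assumes fp16/bf16 weights on CUDA.
model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)).half().cuda()

# New API: pass a config instance instead of a tensor-subclass callable.
sparsify_(model, SemiSparseWeightConfig())

# Block sparsity works the same way; the blocksize now lives on the config.
model2 = nn.Sequential(nn.Linear(128, 128)).half().cuda()
sparsify_(model2, BlockSparseWeightConfig(blocksize=64))
```

The `semi_sparse_weight = SemiSparseWeightConfig` and `block_sparse_weight = BlockSparseWeightConfig` aliases keep the old spelling working: `sparsify_(model, semi_sparse_weight())` still type-checks against the new signature because calling the alias now constructs a config instance.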