
Commit 009f55f

Add layout option to woq int4 api (#670)
* feat: add layout option to woq int4 api
* chore: update tests
* chore: move imports to top of the file
1 parent 174e630 commit 009f55f
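For context, a minimal usage sketch of the new option introduced by this commit; the toy model, dtype, and device below are illustrative assumptions, not part of the change:

import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayoutType

# Hypothetical toy model, bfloat16 on CUDA, used only to show the call shape.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Previously the kernel tiling was a bare keyword: int4_weight_only(group_size=128, inner_k_tiles=8).
# With this commit the tiling lives on the layout object instead.
quantize_(model, int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))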


2 files changed: +16 −30 lines changed


test/integration/test_integration.py

Lines changed: 6 additions & 3 deletions
@@ -19,6 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
+from torchao.dtypes import TensorCoreTiledLayoutType
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -852,18 +853,20 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
 
                     def api(mod):
+                        kwargs_copy = kwargs.copy()
                         if TORCH_VERSION_AFTER_2_4:
-                            kwargs_copy = kwargs.copy()
                             kwargs_copy["group_size"] = groupsize
                             del kwargs_copy["groupsize"]
                             quantize_(mod, int4_weight_only(**kwargs_copy))
                             if not TORCH_VERSION_AFTER_2_5:
                                 unwrap_tensor_subclass(mod)
                         else:
-                            change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+                            kwargs_copy["inner_k_tiles"] = inner_k_tiles
+                            del kwargs_copy["layout_type"]
+                            change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)
 
                     self._test_lin_weight_subclass_api_impl(
                         api,

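The test above now builds a single kwargs dict in the new layout style and translates it back for the legacy change_linear_weights_to_int4_woqtensors path. Below is a standalone sketch of that translation; the helper name to_legacy_int4_kwargs is hypothetical and exists only for illustration:

from torchao.dtypes import TensorCoreTiledLayoutType

def to_legacy_int4_kwargs(groupsize, inner_k_tiles):
    # New-style kwargs, as constructed in the updated test.
    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
    # The legacy API still expects a raw inner_k_tiles value, so the test copies
    # the dict and swaps the layout object back out before calling it.
    legacy = kwargs.copy()
    legacy["inner_k_tiles"] = inner_k_tiles
    del legacy["layout_type"]
    return legacy

# Example: to_legacy_int4_kwargs(64, 4) -> {"groupsize": 64, "inner_k_tiles": 4}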
torchao/quantization/quant_api.py

Lines changed: 10 additions & 27 deletions
@@ -21,7 +21,14 @@
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional
 
-from torchao.dtypes import PlainLayoutType
+from torchao.dtypes.uintx.Uintx import UintxLayoutType
+from torchao.dtypes import (
+    to_affine_quantized,
+    TensorCoreTiledLayoutType,
+    PlainLayoutType,
+    AffineQuantizedTensor,
+    SemiSparseLayoutType
+)
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
@@ -182,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(
 
 
 def _is_linear(mod, *args):
-    # avoid circular dep
-    from torchao.dtypes import AffineQuantizedTensor
-
     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
     # when it is shared by multiple linear modules
     return (
@@ -328,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     )
 
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -339,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
         return weight
 
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
@@ -373,7 +371,7 @@ def insert_subclass(lin):
     return insert_subclass
 
 
-def int4_weight_only(group_size=128, inner_k_tiles=8):
+def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -389,16 +387,12 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
          size is more fine grained, choices are [256, 128, 64, 32]
-        `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
+        `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
     """
     def apply_int4_weight_only_quant(weight):
         if weight.shape[-1] % group_size != 0:
             return weight
 
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-        from torchao.dtypes import TensorCoreTiledLayoutType
-
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.int32
@@ -408,7 +402,6 @@ def apply_int4_weight_only_quant(weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
-        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)
 
     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
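A short migration sketch for callers of int4_weight_only after the hunks above; the specific group size and tile count are picked from the choices listed in the docstring, purely for illustration:

from torchao.quantization.quant_api import int4_weight_only
from torchao.dtypes import TensorCoreTiledLayoutType

# Defaults match the old behaviour: group_size=128 with an 8-tile tensor-core-tiled layout.
default_quant = int4_weight_only()

# Previously: int4_weight_only(group_size=64, inner_k_tiles=2)
# Now the inner_k_tiles choice ([8, 4, 2]) is expressed through the layout object.
custom_quant = int4_weight_only(group_size=64, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=2))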
@@ -419,9 +412,6 @@ def int8_weight_only():
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
     def apply_int8wo_quant(weight):
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -432,8 +422,6 @@ def apply_int8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_int8wo_quant)
 
 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
     eps = 1e-5
@@ -453,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
         if in_features <= 16:
             return weight
 
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
         # weight settings
         mapping_type = MappingType.SYMMETRIC
         def get_weight_block_size(x):
@@ -479,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
 
 
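The hunk above only removes a now-redundant local import, but it also shows the same layout_type pattern this commit brings to int4. A sketch of the equivalence, using only calls visible in the diff:

from torchao.quantization.quant_api import (
    int8_dynamic_activation_int8_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
)
from torchao.dtypes import SemiSparseLayoutType

# The convenience wrapper shown in the hunk...
sparse_quant = int8_dynamic_activation_int8_semi_sparse_weight()
# ...is the same as passing the layout object explicitly, mirroring the int4 change.
also_sparse_quant = int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())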
@@ -495,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
         quantize_affine,
         dequantize_affine,
     )
-    from torchao.dtypes.uintx.Uintx import UintxLayoutType
-    from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
     def apply_uintx_weight_only_quant(weight):
 