import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional

-from torchao.dtypes import PlainLayoutType
+from torchao.dtypes.uintx.Uintx import UintxLayoutType
+from torchao.dtypes import (
+    to_affine_quantized,
+    TensorCoreTiledLayoutType,
+    PlainLayoutType,
+    AffineQuantizedTensor,
+    SemiSparseLayoutType
+)
from torchao.utils import (
    TORCH_VERSION_AFTER_2_4,
    unwrap_tensor_subclass,
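The deferred "avoid circular dep" imports removed in the hunks below are replaced by this module-level import block. As a minimal sketch (not part of the diff), the layout types it exposes can be constructed directly; the constructor arguments mirror the calls that appear later in this file:

from torchao.dtypes import TensorCoreTiledLayoutType, SemiSparseLayoutType

# Layout used by int4_weight_only below; inner_k_tiles=8 matches its new default.
int4_layout = TensorCoreTiledLayoutType(inner_k_tiles=8)
# Layout used by int8_dynamic_activation_int8_semi_sparse_weight below.
sparse_layout = SemiSparseLayoutType()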
@@ -182,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(


def _is_linear(mod, *args):
-    # avoid circular dep
-    from torchao.dtypes import AffineQuantizedTensor
-
    # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
    # when it is shared by multiple linear modules
    return (
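The comment above refers to an isinstance guard on the weight; the condition inside `return (...)` is not shown in this hunk. A hedged sketch of what such a guard can look like, using the now module-level `AffineQuantizedTensor`:

import torch
from torchao.dtypes import AffineQuantizedTensor

# Illustrative only; the real _is_linear check may test additional subclass types.
def _looks_like_unquantized_linear(mod):
    return (
        isinstance(mod, torch.nn.Linear)
        and not isinstance(mod.weight, AffineQuantizedTensor)  # skip already-quantized weights
    )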
@@ -328,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
    )

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
    mapping_type = MappingType.ASYMMETRIC
    target_dtype = torch.int8
    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
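`_get_per_token_block_size` is not part of this diff; the sketch below illustrates the per-token granularity assumed here, i.e. one quantization group per element of every leading dimension, spanning the full last dimension:

# Assumed behavior, for illustration only.
def per_token_block_size(x):
    return [1] * (x.dim() - 1) + [x.shape[-1]]

# e.g. activations of shape (batch=2, seq=8, hidden=512) -> [1, 1, 512]:
# each of the 2 * 8 tokens gets its own int8 scale and zero point.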
@@ -339,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
    if weight.shape[-1] % group_size != 0:
        return weight

-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
    # weight settings
    mapping_type = MappingType.SYMMETRIC
    block_size = (1, group_size)
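Here `block_size = (1, group_size)` means each run of `group_size` consecutive elements along a weight row shares one scale/zero point. A small worked illustration (not from the diff):

import torch

group_size = 32
weight = torch.randn(128, 256)                      # (out_features, in_features)
groups_per_row = weight.shape[-1] // group_size     # 256 // 32 = 8
num_groups = weight.shape[0] * groups_per_row       # 128 * 8 quantization groups in total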
@@ -373,7 +371,7 @@ def insert_subclass(lin):
    return insert_subclass


-def int4_weight_only(group_size=128, inner_k_tiles=8):
+def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
    """
    Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
    "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -389,16 +387,12 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
    Args:
        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained, choices are [256, 128, 64, 32]
-        `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
+        `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
    """
    def apply_int4_weight_only_quant(weight):
        if weight.shape[-1] % group_size != 0:
            return weight

-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-        from torchao.dtypes import TensorCoreTiledLayoutType
-
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int32
@@ -408,7 +402,6 @@ def apply_int4_weight_only_quant(weight):
        preserve_zero = False
        zero_point_dtype = torch.bfloat16
        zero_point_domain = ZeroPointDomain.FLOAT
-        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
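With the new signature the packing layout is passed as a layout object instead of a bare `inner_k_tiles` integer. A hedged usage sketch (only the config construction is shown; the top-level quantize entry point is outside this diff):

from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization.quant_api import int4_weight_only

# Previously: int4_weight_only(group_size=64, inner_k_tiles=4)
apply_int4 = int4_weight_only(
    group_size=64,
    layout_type=TensorCoreTiledLayoutType(inner_k_tiles=4),
)
# apply_int4 is the linear-subclass inserter returned by
# _get_linear_subclass_inserter(apply_int4_weight_only_quant) above.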
@@ -419,9 +412,6 @@ def int8_weight_only():
    Applies int8 weight-only symmetric per-channel quantization to linear layers.
    """
    def apply_int8wo_quant(weight):
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
        mapping_type = MappingType.SYMMETRIC
        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
@@ -432,8 +422,6 @@ def apply_int8wo_quant(weight):
    return _get_linear_subclass_inserter(apply_int8wo_quant)
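As a back-of-the-envelope illustration of what "symmetric per-channel" means here (not the torchao implementation, which goes through `to_affine_quantized`):

import torch

w = torch.randn(64, 128)                            # (out_channels, in_features)
scale = w.abs().amax(dim=1, keepdim=True) / 127.0   # one scale per output channel
w_int8 = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
w_dq = w_int8.to(torch.float32) * scale             # symmetric: no zero point needed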

def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = 1e-5
@@ -453,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
        if in_features <= 16:
            return weight

-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
        # weight settings
        mapping_type = MappingType.SYMMETRIC
        def get_weight_block_size(x):
@@ -479,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
    quantization + 2:4 sparsity to linear layers.
    """
-    from torchao.dtypes import SemiSparseLayoutType
    return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())

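With `SemiSparseLayoutType` imported at module level, this helper just forwards to `int8_dynamic_activation_int8_weight` with a sparse layout; per the return statement above, the two calls below should be interchangeable (a sketch, assuming both names are importable from `torchao.quantization.quant_api`):

from torchao.dtypes import SemiSparseLayoutType
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int8_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
)

cfg_a = int8_dynamic_activation_int8_semi_sparse_weight()
cfg_b = int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())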
@@ -495,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
        quantize_affine,
        dequantize_affine,
    )
-    from torchao.dtypes.uintx.Uintx import UintxLayoutType
-    from torchao.dtypes import to_affine_quantized
    from torchao.quantization.quant_api import _get_linear_subclass_inserter
    def apply_uintx_weight_only_quant(weight):