From 867ec9a71162a29f4fd72cda6e0090232b5f9f90 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 21 Jun 2024 14:50:42 -0700 Subject: [PATCH] groupsize consistency Summary: half of the apis used groupsize and half used group_size, swapping them all to groupsize Test Plan: python eval.py -q int8wo --limit 1 wikitext: {'word_perplexity,none': 12.204889603121593, 'byte_perplexity,none': 1.5965674184201175, 'bits_per_byte,none': 0.6749734750293632, 'alias': 'wikitext'} python generate.py --quantization int4wo-64 Average tokens/sec: 13.93 Average Bandwidth: 52.04 GB/s Peak Memory Usage: 15.92 GB Model Size: 3.74 GB Reviewers: Subscribers: Tasks: Tags: --- benchmarks/benchmark_hqq.py | 22 +++++----- benchmarks/dora/bench_utils.py | 2 +- test/dora/test_dora_layer.py | 2 +- test/dtypes/test_affine_quantized.py | 2 +- test/hqq/test_triton_mm.py | 30 ++++++------- test/hqq/test_triton_qkv_fused.py | 24 +++++----- test/integration/test_integration.py | 2 +- test/quantization/test_galore_quant.py | 4 +- test/quantization/test_qat.py | 44 +++++++++---------- test/quantization/test_quant_api.py | 12 ++--- test/quantization/test_quant_primitives.py | 6 +-- torchao/_models/llama/eval.py | 8 ++-- torchao/_models/llama/generate.py | 14 +++--- torchao/kernel/intmm_triton.py | 6 +-- torchao/prototype/dora/dora_profile.py | 2 +- torchao/prototype/dora/kernels/common.py | 6 +-- torchao/prototype/dora/kernels/smallk.py | 6 +-- torchao/prototype/fp8/splitk_gemm.py | 6 +-- .../galore/kernels/adam_downproj_fused.py | 6 +-- torchao/prototype/galore/kernels/matmul.py | 6 +-- torchao/prototype/galore/kernels/quant.py | 22 +++++----- torchao/prototype/hqq/README.md | 10 ++--- torchao/prototype/hqq/hqq_tinygemm_linear.py | 6 +-- torchao/prototype/hqq/kernels.py | 8 ++-- torchao/prototype/hqq/mixed_mm.py | 14 +++--- torchao/quantization/README.md | 4 +- torchao/quantization/prototype/qat.py | 10 ++--- torchao/quantization/quant_api.py | 12 ++--- torchao/quantization/utils.py | 6 +-- tutorials/quantize_vit/bfloat16_code.py | 6 +-- tutorials/quantize_vit/quant_code.py | 24 +++++----- 31 files changed, 166 insertions(+), 166 deletions(-) diff --git a/benchmarks/benchmark_hqq.py b/benchmarks/benchmark_hqq.py index 123e2e5f52..e9091410e6 100644 --- a/benchmarks/benchmark_hqq.py +++ b/benchmarks/benchmark_hqq.py @@ -31,7 +31,7 @@ def bench_custom_kernel( W_q, scales, zeros, - group_size, + groupsize, transposed=False, kernel_type="max_autotune", fp8_fast_accum=False, @@ -45,7 +45,7 @@ def fn(): scales.T, zeros.T, transposed=transposed, - group_size=group_size, + groupsize=groupsize, fp8_fast_accum=fp8_fast_accum, kernel_type=kernel_type, ) @@ -65,11 +65,11 @@ def reference_fn(): def run_benchmark( - shape, group_size, dtype, axis=1, transposed=False, quant_dtype=torch.uint8 + shape, groupsize, dtype, axis=1, transposed=False, quant_dtype=torch.uint8 ): qcfg = { **BASE_QUANT_CONFIG, - **dict(group_size=group_size, axis=axis), + **dict(groupsize=groupsize, axis=axis), } M, N, K = shape @@ -103,7 +103,7 @@ def run_benchmark( scales = scales.reshape(N, -1) zeros = zeros.reshape(N, -1) tt_time = bench_custom_kernel( - x, W_q, scales, zeros, group_size, transposed=transposed + x, W_q, scales, zeros, groupsize, transposed=transposed ) should_run_tinygemm = dtype == torch.bfloat16 and not transposed @@ -114,7 +114,7 @@ def run_benchmark( ) int4_time = bench_hqq(x, hqq_int4mm, transposed=transposed, tinygemm=True) - print(f"{shape=}, {group_size=}, {dtype=}, {transposed=}:") + print(f"{shape=}, {groupsize=}, {dtype=}, {transposed=}:") 
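# Illustrative reference (not part of this patch): a pure-PyTorch sketch of the groupwise
# dequantize + matmul that the fused kernel benchmarked above computes. Shapes follow the
# tests in this patch; all names below are illustrative only.
import torch

def dequant_matmul_ref(x, W_q, scales, zeros, groupsize):
    # x: (M, K) activations; W_q: (N, K) 4-bit codes stored as uint8
    # scales / zeros: one value per group of `groupsize` weights, shape (N * K // groupsize, 1)
    N, K = W_q.shape
    W_dq = ((W_q.to(x.dtype).reshape(-1, groupsize) - zeros) * scales).reshape(N, K)
    return x @ W_dq.T  # (M, N)

M, N, K, groupsize = 16, 4096, 4096, 128
x = torch.randn(M, K)
W_q = torch.randint(0, 16, (N, K), dtype=torch.uint8)
scales = torch.rand(N * K // groupsize, 1)
zeros = torch.full_like(scales, 8.0)
out = dequant_matmul_ref(x, W_q, scales, zeros, groupsize)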
print( f"Ref: {ref_time:.4f}ms", @@ -146,7 +146,7 @@ def run_benchmark( "M", "N", "K", - "group_size", + "groupsize", "dtype", "transposed", "ref", @@ -159,16 +159,16 @@ def run_benchmark( print(torch.cuda.get_device_properties(0)) for shape in SHAPES: - for group_size in GROUP_SIZES: + for groupsize in GROUP_SIZES: for dtype in DTYPES: for transposed in TRANSPOSED: timings = run_benchmark( - shape, group_size, dtype, transposed=transposed + shape, groupsize, dtype, transposed=transposed ) - data.append((*shape, group_size, dtype, transposed, *timings)) + data.append((*shape, groupsize, dtype, transposed, *timings)) output = StringIO() df = pd.DataFrame(data, columns=HEADERS) df.to_csv(output, index=False) print(output.getvalue()) - # df.to_csv("benchmark_hqq_tinygemm.csv", index=False) \ No newline at end of file + # df.to_csv("benchmark_hqq_tinygemm.csv", index=False) diff --git a/benchmarks/dora/bench_utils.py b/benchmarks/dora/bench_utils.py index 2de4fa637a..cc170c24a6 100644 --- a/benchmarks/dora/bench_utils.py +++ b/benchmarks/dora/bench_utils.py @@ -111,7 +111,7 @@ def setup_dora_base_layers(layer_type, in_features, out_features, dtype): # HQQ quant_config = BaseQuantizeConfig( nbits=4, - group_size=64, + groupsize=64, quant_zero=False, quant_scale=False, offload_meta=True, diff --git a/test/dora/test_dora_layer.py b/test/dora/test_dora_layer.py index dd38cc8d6b..628fa1cc9d 100644 --- a/test/dora/test_dora_layer.py +++ b/test/dora/test_dora_layer.py @@ -91,7 +91,7 @@ def test_dora_layer( elif model_type == "HQQDoRALinear": quant_config = BaseQuantizeConfig( nbits=4, - group_size=64, + groupsize=64, quant_zero=False, quant_scale=False, offload_meta=True, diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 05e84d5006..506be1b5bc 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -12,7 +12,7 @@ class TestAffineQuantized(TestCase): def test_tensor_core_layout_transpose(self): t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda") shape = t.shape - apply_int4_weight_only_quant = int4_weight_only(group_size=32) + apply_int4_weight_only_quant = int4_weight_only(groupsize=32) aqt = apply_int4_weight_only_quant(t) aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index 4684f28221..e7a3737718 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -72,16 +72,16 @@ def _arg_to_id(arg): @pytest.mark.parametrize( - "shape, group_size, axis, dtype, transposed, kernel_type", + "shape, groupsize, axis, dtype, transposed, kernel_type", TEST_CONFIGS, ids=_arg_to_id, ) def test_mixed_mm( - shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8 + shape, groupsize, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8 ): qcfg = { **BASE_QUANT_CONFIG, - **dict(group_size=group_size, axis=axis), + **dict(groupsize=groupsize, axis=axis), } M, N, K = shape @@ -117,7 +117,7 @@ def test_mixed_mm( scales.T, zeros.T, transposed=True, - group_size=group_size, + groupsize=groupsize, fp8_fast_accum=False, kernel_type=kernel_type, ) @@ -132,7 +132,7 @@ def test_mixed_mm( scales.T, zeros.T, transposed=False, - group_size=group_size, + groupsize=groupsize, fp8_fast_accum=False, kernel_type=kernel_type, ) @@ -147,7 +147,7 @@ def test_mixed_mm( # Only for debugging kernel without dependency on HQQ and with no autotuning def _test_mixed_mm( shape, - group_size, + groupsize, BLOCK_M, BLOCK_N, 
BLOCK_K, @@ -159,7 +159,7 @@ def _test_mixed_mm( ): qcfg = { **BASE_QUANT_CONFIG, - **dict(group_size=group_size, axis=axis), + **dict(groupsize=groupsize, axis=axis), } M, N, K = shape @@ -169,9 +169,9 @@ def _test_mixed_mm( quant_config.update({"weight_quant_params": qcfg}) W_q = torch.randint(0, int(2**4), size=(N, K), dtype=quant_dtype, device="cuda") - scales = torch.arange((N * K) // group_size, dtype=dtype, device="cuda")[:, None] + scales = torch.arange((N * K) // groupsize, dtype=dtype, device="cuda")[:, None] zeros = torch.zeros_like(scales) - W_dq = ((W_q.reshape(-1, group_size) - zeros) * scales).reshape(N, K) + W_dq = ((W_q.reshape(-1, groupsize) - zeros) * scales).reshape(N, K) scales = scales.reshape(N, -1) zeros = zeros.reshape(N, -1) @@ -187,7 +187,7 @@ def _test_mixed_mm( scales.T, zeros.T, transposed=True, - group_size=group_size, + groupsize=groupsize, fp8_fast_accum=False, kernel_type=kernel_type, BLOCK_M=BLOCK_M, @@ -205,14 +205,14 @@ def _test_mixed_mm( scales.T, zeros.T, transposed=False, - group_size=group_size, + groupsize=groupsize, fp8_fast_accum=False, kernel_type=kernel_type, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, ) - msg = f"shape={shape}, group_size={group_size}, axis={axis}, dtype={dtype}, transposed={transposed}, kernel_type={kernel_type}, quant_dtype={quant_dtype}" + msg = f"shape={shape}, groupsize={groupsize}, axis={axis}, dtype={dtype}, transposed={transposed}, kernel_type={kernel_type}, quant_dtype={quant_dtype}" check( hqq_out, @@ -229,10 +229,10 @@ def _test_mixed_mm( BLOCK_M, BLOCK_N, BLOCK_K = shape BLOCK_K = K // 2 BLOCK_N = N // 2 - group_size = BLOCK_K + groupsize = BLOCK_K _test_mixed_mm( shape, - group_size=group_size, + groupsize=groupsize, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, @@ -240,7 +240,7 @@ def _test_mixed_mm( ) _test_mixed_mm( shape, - group_size=group_size, + groupsize=groupsize, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, diff --git a/test/hqq/test_triton_qkv_fused.py b/test/hqq/test_triton_qkv_fused.py index eda171a9ca..6797f70933 100644 --- a/test/hqq/test_triton_qkv_fused.py +++ b/test/hqq/test_triton_qkv_fused.py @@ -75,13 +75,13 @@ def fuse_qkv(W_qs, scales, zeros): """ Args: W_qs (list[torch.Tensor]): len 3 list of tensors with shapes Nq x K, Nk x K, Nv x K where Nk == Nv - scales (list[torch.Tensor]): each is N x (K // group_size), with same N requirements per W_qs + scales (list[torch.Tensor]): each is N x (K // groupsize), with same N requirements per W_qs zeros (list[torch.Tensor]): same as scales Returns: qkv (torch.Tensor): (N_qkv x K) where N_qkv = Nq + Nk + Nv - scales (torch.Tensor): (N_qkv x (K // group_size)) - zeros (torch.Tensor): (N_qkv x (K // group_size)) + scales (torch.Tensor): (N_qkv x (K // groupsize)) + zeros (torch.Tensor): (N_qkv x (K // groupsize)) """ qkv = torch.cat(W_qs, dim=0) # Fuse along N fused_scales = torch.cat([s for s in scales], dim=0) @@ -89,28 +89,28 @@ def fuse_qkv(W_qs, scales, zeros): return qkv, fused_scales, fused_zeros -def ref_proj(x, packed_w, scale, zero, group_size, kernel_type, transposed=False): +def ref_proj(x, packed_w, scale, zero, groupsize, kernel_type, transposed=False): return triton_mixed_mm( x, packed_w, scale.T, zero.T, transposed=transposed, - group_size=group_size, + groupsize=group_size, fp8_fast_accum=False, kernel_type=kernel_type, ) @pytest.mark.parametrize( - "q_shape, kv_shape, group_size, axis, dtype, transposed, kernel_type", + "q_shape, kv_shape, groupsize, axis, dtype, transposed, kernel_type", TEST_CONFIGS, ids=_arg_to_id, 
) def test_mixed_mm( q_shape, kv_shape, - group_size, + groupsize, axis, dtype, transposed, @@ -136,7 +136,7 @@ def test_mixed_mm( qcfg = { **BASE_QUANT_CONFIG, - **dict(group_size=group_size, axis=axis), + **dict(groupsize=group_size, axis=axis), } quant_config = BaseQuantizeConfig( @@ -172,7 +172,7 @@ def test_mixed_mm( xs = [torch.randn(seqlen, n, dtype=dtype, device=device) for n in Ns] x_fused = torch.cat(xs, dim=1) q_ref, k_ref, v_ref = [ - ref_proj(x, p, s, z, group_size, kernel_type, transposed=True) + ref_proj(x, p, s, z, groupsize, kernel_type, transposed=True) for x, p, s, z in zip(xs, packed_ws, scales, zeros) ] tt_fused = triton_mixed_mm( @@ -181,7 +181,7 @@ def test_mixed_mm( scales_fused.T, zeros_fused.T, transposed=True, - group_size=group_size, + groupsize=group_size, fp8_fast_accum=False, kernel_type=kernel_type, ) @@ -191,7 +191,7 @@ def test_mixed_mm( x = torch.randn(seqlen, K, dtype=dtype, device=device) q_ref, k_ref, v_ref = [ - ref_proj(x, p, s, z, group_size, kernel_type) + ref_proj(x, p, s, z, groupsize, kernel_type) for p, s, z in zip(packed_ws, scales, zeros) ] @@ -201,7 +201,7 @@ def test_mixed_mm( scales_fused.T, zeros_fused.T, transposed=False, - group_size=group_size, + groupsize=group_size, fp8_fast_accum=False, kernel_type=kernel_type, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index b4fbcb152a..a1981cd433 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -833,7 +833,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): def api(mod): if TORCH_VERSION_AFTER_2_4: kwargs_copy = kwargs.copy() - kwargs_copy["group_size"] = groupsize + kwargs_copy["groupsize"] = groupsize del kwargs_copy["groupsize"] quantize(mod, int4_weight_only(**kwargs_copy)) unwrap_tensor_subclass(mod) diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 1eabf479ce..c97e38f0b9 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -42,7 +42,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape) tt_q, tt_norm, tt_absmax = triton_quantize_blockwise( - g, qmap, group_size=blocksize, return_normalized=True + g, qmap, groupsize=blocksize, return_normalized=True ) tt_check = torch.allclose(ref_bnb, tt_q) @@ -87,5 +87,5 @@ def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize) dq_ref = F.dequantize_blockwise(q, qstate) - dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize) + dq = triton_dequant_blockwise(q, qmap, qstate.absmax, groupsize=blocksize) assert torch.allclose(dq, dq_ref) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 433fdcb28a..5807179131 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -63,26 +63,26 @@ def _get_qmin_qmax(self, n_bit: int): def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = self._get_qmin_qmax(n_bit) - group_size = 128 + groupsize = 128 torch.manual_seed(self.SEED) x = torch.randn(100, 256).requires_grad_() - (s, zp) = get_group_qparams_symmetric(x, n_bit, group_size) + (s, zp) = get_group_qparams_symmetric(x, n_bit, groupsize) zp = zp.to(torch.int32) x2 = copy.deepcopy(x) # fake quant op out = fake_quantize_per_channel_group( - x, s, zp, qmin, qmax, group_size, + x, 
s, zp, qmin, qmax, groupsize, ) out.sum().backward() # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group( - x2, s, zp, qmin, qmax, torch.int8, group_size, + x2, s, zp, qmin, qmax, torch.int8, groupsize, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group( - out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32, + out_ptq, s, zp, qmin, qmax, torch.int8, groupsize, torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) @@ -113,7 +113,7 @@ def _set_ptq_weight( self, ptq_linear: "Int8DynActInt4WeightLinear", fp32_weight: torch.Tensor, - group_size: int, + groupsize: int, ): """ Set the weight to the quantized version of the given fp32 weights, @@ -121,9 +121,9 @@ def _set_ptq_weight( """ n_bit = 4 (qmin, qmax) = self._get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) + (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, groupsize) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + fp32_weight, s, zp, qmin, qmax, torch.int8, groupsize, ) ptq_linear.weight = q_weight ptq_linear.scales = s @@ -134,17 +134,17 @@ def test_qat_8da4w_linear(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear - group_size = 128 + groupsize = 128 torch.manual_seed(self.SEED) qat_linear = Int8DynActInt4WeightQATLinear( - 256, 688, bias=False, groupsize=group_size, + 256, 688, bias=False, groupsize=groupsize, ) ptq_linear = Int8DynActInt4WeightLinear( - 256, 688, bias=False, groupsize=group_size, + 256, 688, bias=False, groupsize=groupsize, ) # Force the weights to be the same - self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size) + self._set_ptq_weight(ptq_linear, qat_linear.weight, groupsize) # Compare linear values torch.manual_seed(self.SEED) @@ -159,12 +159,12 @@ def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer - group_size = 16 + groupsize = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) - qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) - ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) + qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=groupsize) + ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize) qat_model = qat_quantizer.prepare(m) ptq_model = ptq_quantizer.quantize(m2) @@ -195,8 +195,8 @@ def test_qat_8da4w_quantizer_meta_weights(self): with torch.device("meta"): m = M() self.assertTrue(all(v.is_meta for v in m.state_dict().values())) - group_size = 16 - qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + groupsize = 16 + qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=groupsize) qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) @@ -211,12 +211,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): enable_8da4w_fake_quant, ) - group_size = 16 + groupsize = 16 torch.manual_seed(self.SEED) m = M() m2 = copy.deepcopy(m) m3 = copy.deepcopy(m) - quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=groupsize) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) 
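# Rough standalone sketch (an approximation, not the patch's implementation, which also routes
# the round trip through a straight-through-estimator autograd op) of the grouped fake-quantize
# exercised by the QAT tests above: values are quantized and immediately dequantized in floating
# point, with one (scale, zero_point) pair per `groupsize` elements.
import torch

def fake_quantize_per_channel_group_ref(x, scales, zero_points, quant_min, quant_max, groupsize):
    assert x.dim() == 2 and x.shape[-1] % groupsize == 0
    grouped = x.reshape(-1, groupsize).float()
    s = scales.reshape(-1, 1).float()
    zp = zero_points.reshape(-1, 1).float()
    q = torch.clamp(torch.round(grouped / s) + zp, quant_min, quant_max)
    return ((q - zp) * s).reshape(x.shape).to(x.dtype)

x = torch.randn(100, 256)
groupsize, quant_min, quant_max = 128, -8, 7  # 4-bit range, as in the test above
scales = x.reshape(-1, groupsize).abs().amax(dim=1, keepdim=True) / quant_max
out = fake_quantize_per_channel_group_ref(
    x, scales, torch.zeros_like(scales), quant_min, quant_max, groupsize
)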
self.assertFalse(qat_model.linear1._fake_quant_enabled) @@ -241,7 +241,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): self.assertTrue(qat_model.sub.linear._fake_quant_enabled) # Fake quant should be applied as normal - quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=groupsize) qat_model2 = quantizer2.prepare(m3) qat_model2.linear1.weight = qat_model.linear1.weight qat_model2.linear2.weight = qat_model.linear2.weight @@ -263,11 +263,11 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): disable_8da4w_fake_quant, ) - group_size = 16 + groupsize = 16 torch.manual_seed(self.SEED) m = M() nn_model = copy.deepcopy(m) - quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=groupsize) qat_model = quantizer.prepare(m) qat_model.apply(disable_8da4w_fake_quant) nn_model.linear1.weight = qat_model.linear1.weight diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b22a157568..ead7218215 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -499,11 +499,11 @@ def test_eval_wrapper_llama3(self): # TODO: move to a separate test file @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") def test_quantized_tensor_subclass_8da4w(self): - group_size = 32 + groupsize = 32 m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size)) + m = quantize(m, int8_dynamic_activation_int4_weight(groupsize=groupsize)) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) @@ -514,7 +514,7 @@ def test_quantized_tensor_subclass_8da4w(self): from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear - quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size) + quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize) m_copy = quantizer.quantize(m_copy) assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear) assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear) @@ -531,13 +531,13 @@ def test_quantized_tensor_subclass_int4(self): m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") - group_size = 32 - m = quantize(m, int4_weight_only(group_size=group_size)) + groupsize = 32 + m = quantize(m, int4_weight_only(groupsize=groupsize)) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) # reference - _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size) + _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize) res = m(*example_inputs) ref = m_copy(*example_inputs) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 0e5388c301..56f51cef61 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -286,14 +286,14 @@ def test_quantize_dequantize_group_sym(self): quantized = quantize_affine(input, block_size, scale, zero_point, dtype) dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) - group_size = 2 + groupsize = 2 quant_min = -128 quant_max = 127 
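# Minimal usage sketch of the renamed keyword, mirroring the tensor-subclass tests above.
# Assumes a CUDA device and the torchao APIs exactly as they appear in this patch; the toy
# model below is a placeholder, not taken from the test suite.
import torch
from torchao.quantization.quant_api import quantize, int4_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")
model = quantize(model, int4_weight_only(groupsize=32))  # previously int4_weight_only(group_size=32)
x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
y = model(x)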
quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel_group( - input, scale, zero_point, quant_min, quant_max, torch.int8, group_size + input, scale, zero_point, quant_min, quant_max, torch.int8, groupsize ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group( - quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32 + quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, groupsize, output_dtype=torch.float32 ) self.assertTrue(torch.equal(quantized, quantized_ref)) diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index 7842c3e66c..8d121e89b5 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -13,7 +13,7 @@ ) from torchao.quantization.quant_api import ( - quantize, int4wo, int8wo, int8da_int8w, unwrap_tensor_subclass + quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass ) from torchao._models._eval import TransformerEvalWrapper, InputRecorder @@ -60,13 +60,13 @@ def run_evaluation( if quantization: if "int8wo" in quantization: - quantize(model, int8wo()) + quantize(model, int8_weight_only()) if "int8dq" in quantization: - quantize(model, int8da_int8w()) + quantize(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization and not "gptq" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize(model, int4wo(groupsize=groupsize)) + quantize(model, int4_weight_only(groupsize=groupsize)) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 7f7cfab885..2bb3aa5961 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -189,21 +189,21 @@ def main( if quantization: from torchao.quantization.quant_api import ( quantize, - int8wo, - int8da_int8w, - int4wo, + int8_weight_only, + int8_dynamic_activation_int8_weight, + int4_weight_only, autoquant, unwrap_tensor_subclass ) if "int8wo" in quantization: - quantize(model, int8wo()) + quantize(model, int8_weight_only()) if "int8dq" in quantization: - quantize(model, int8da_int8w()) + quantize(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize(model, int4wo(groupsize=groupsize)) + quantize(model, int4_weight_only(groupsize=groupsize)) if "autoquant" == quantization: model = autoquant(model, manual=True) @@ -339,7 +339,7 @@ def callback(x): parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') 
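# How the --quantization strings above map onto the renamed APIs. The dispatch logic is copied
# from generate.py in this patch; the model below is only a placeholder (int4 weight-only
# quantization assumes a bfloat16 model on a CUDA device).
import torch
from torchao.quantization.quant_api import (
    quantize, int8_weight_only, int8_dynamic_activation_int8_weight, int4_weight_only,
)

model = torch.nn.Linear(4096, 4096).to(torch.bfloat16).to("cuda")
quantization = "int4wo-64"

if "int8wo" in quantization:
    quantize(model, int8_weight_only())
if "int8dq" in quantization:
    quantize(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
    groupsize = int(quantization.split("-")[-1])  # "int4wo-64" -> 64
    assert groupsize in [32, 64, 128, 256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
    quantize(model, int4_weight_only(groupsize=groupsize))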
parser.add_argument("--quantization", type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index d10dac0abe..25c8c4661b 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -199,9 +199,9 @@ def scaled_matmul_kernel_with_block_pointers( # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/torchao/prototype/dora/dora_profile.py b/torchao/prototype/dora/dora_profile.py index bf87769742..044e8dc487 100644 --- a/torchao/prototype/dora/dora_profile.py +++ b/torchao/prototype/dora/dora_profile.py @@ -62,7 +62,7 @@ def run(args): elif args.layer_type == "hqq": quant_config = BaseQuantizeConfig( nbits=4, - group_size=64, + groupsize=64, quant_zero=False, quant_scale=False, offload_meta=True, diff --git a/torchao/prototype/dora/kernels/common.py b/torchao/prototype/dora/kernels/common.py index cd0950d4c0..ada58d76a0 100644 --- a/torchao/prototype/dora/kernels/common.py +++ b/torchao/prototype/dora/kernels/common.py @@ -161,9 +161,9 @@ def swizzle_tile( # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) elif SWIZZLE == tl.constexpr(SwizzleType.COLUMN_MAJOR): pid_m = pid % grid_m pid_n = pid // grid_m diff --git a/torchao/prototype/dora/kernels/smallk.py b/torchao/prototype/dora/kernels/smallk.py index fc24ea223f..0a9a6937ce 100644 --- a/torchao/prototype/dora/kernels/smallk.py +++ b/torchao/prototype/dora/kernels/smallk.py @@ -156,9 +156,9 @@ def swizzle_tile( # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) else: tl.static_assert(False, "swizzle type not supported") diff --git a/torchao/prototype/fp8/splitk_gemm.py b/torchao/prototype/fp8/splitk_gemm.py index 1efaa731db..4ce7d86b94 100644 --- a/torchao/prototype/fp8/splitk_gemm.py +++ b/torchao/prototype/fp8/splitk_gemm.py @@ -14,10 +14,10 @@ def grouped_launch(pid, width = group_m * grid_n group_id = pid // width - group_size = tl.minimum(grid_m - group_id * group_m, group_m) + groupsize = tl.minimum(grid_m - group_id * group_m, group_m) - pid_m = group_id * group_m + (pid % group_size) - pid_n = (pid % width) // group_size + pid_m = group_id * 
group_m + (pid % groupsize) + pid_n = (pid % width) // groupsize return pid_m, pid_n diff --git a/torchao/prototype/galore/kernels/adam_downproj_fused.py b/torchao/prototype/galore/kernels/adam_downproj_fused.py index 9049baa782..52ab74a359 100644 --- a/torchao/prototype/galore/kernels/adam_downproj_fused.py +++ b/torchao/prototype/galore/kernels/adam_downproj_fused.py @@ -67,9 +67,9 @@ def _fused_adam_mm_kernel( # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) # do matrix multiplication rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/torchao/prototype/galore/kernels/matmul.py b/torchao/prototype/galore/kernels/matmul.py index b183f7ed66..224374da83 100644 --- a/torchao/prototype/galore/kernels/matmul.py +++ b/torchao/prototype/galore/kernels/matmul.py @@ -229,9 +229,9 @@ def _matmul_kernel( # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) # do matrix multiplication rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/torchao/prototype/galore/kernels/quant.py b/torchao/prototype/galore/kernels/quant.py index 516b741eab..834ff27a7e 100644 --- a/torchao/prototype/galore/kernels/quant.py +++ b/torchao/prototype/galore/kernels/quant.py @@ -36,7 +36,7 @@ def _dequant_kernel( def triton_dequant_blockwise( - q: torch.Tensor, qmap: torch.Tensor, absmax: torch.Tensor, group_size: int + q: torch.Tensor, qmap: torch.Tensor, absmax: torch.Tensor, groupsize: int ): M, N = q.shape dq = torch.empty_like(q).to(absmax.dtype) @@ -52,8 +52,8 @@ def triton_dequant_blockwise( q.stride(0), q.stride(1), BLOCK_M=1, - BLOCK_N=group_size, - GROUP_SIZE=group_size, + BLOCK_N=groupsize, + GROUP_SIZE=groupsize, ) return dq @@ -99,30 +99,30 @@ def _quantize_blockwise_kernel( q = tl.sum(q, axis=1) tl.store(q_ptr + offsets, q, mask=mask) - # Each block processes one group_size number of elements, hence 1 absmax + # Each block processes one groupsize number of elements, hence 1 absmax tl.store(absmax_ptr + pid, absmax, mask=absmax_mask) if RETURN_NORM: tl.store(norm_ptr + offsets, normalized, mask=mask) -# NOTE: Each block processes one group_size number of elements, hence BLOCK_SIZE = group_size -# where group_size corresponds to the groupwise quantization blocksize +# NOTE: Each block processes one groupsize number of elements, hence BLOCK_SIZE = groupsize +# where groupsize corresponds to the groupwise quantization blocksize def triton_quantize_blockwise( - t: torch.Tensor, code, group_size=2048, return_normalized=False + t: torch.Tensor, code, groupsize=2048, return_normalized=False ): """ Params: t: torch.Tensor, tensor to quantize code: torch.Tensor, quantization codebook for bitsandbytes, output of `bitsandbytes.functional.create_dynamic_map` # absmax: torch.Tensor, absolute max values for each block, if None, will be calculated from the input 
tensor - group_size: int, groupwise quantization blocksize, default 2048, the hardcoded blocksize for bitsandbytes 8-bit optimizers + groupsize: int, groupwise quantization blocksize, default 2048, the hardcoded blocksize for bitsandbytes 8-bit optimizers return_normalized: bool, if True, will return the normalized tensor, primarily for debugging """ numel = t.numel() q = torch.empty(numel, dtype=torch.uint8, device=t.device) normalized = torch.empty_like(t) if return_normalized else None - num_groups = numel // group_size + num_groups = numel // groupsize abs_max = torch.empty(num_groups, dtype=t.dtype, device="cuda") # Cutoffs for quantization # code corresponds to actual (normalized) quant codes @@ -141,7 +141,7 @@ def triton_quantize_blockwise( assert cutoffs.numel() % 2 == 0 grid = lambda META: (triton.cdiv(t.numel(), META["BLOCK_SIZE"]),) - # assert t.numel() % group_size == 0 + # assert t.numel() % groupsize == 0 _quantize_blockwise_kernel[grid]( t.view(-1), cutoffs, @@ -150,7 +150,7 @@ def triton_quantize_blockwise( normalized.view(-1) if return_normalized else None, numel, NUM_BUCKETS=len(cutoffs), - BLOCK_SIZE=group_size, + BLOCK_SIZE=groupsize, RETURN_NORM=return_normalized, ) return ( diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 8bf1d34260..9412759342 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -9,7 +9,7 @@ The kernel fuses two ops: Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. -> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`. +> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`. > The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). If interested in optimizing the kernel for other architectures, please drop a note in the CUDA-MODE Discord channel. ### Usage @@ -29,9 +29,9 @@ The pseudocode below explains the expected shapes and dtypes. Also see `test/hqq #The reason we use N x K is to match that shape of the weight for a torch.nn.Linear layer, where N -> out-features, K -> in-features weights = torch.randn(N, K, dtype=torch.float16, device="cuda") -#Perform groupwise asymmetric quantization along axis=1 (in-features). E.g., `scales = Wq.reshape(-1, group_size).max(axis=1)`. +#Perform groupwise asymmetric quantization along axis=1 (in-features). E.g., `scales = Wq.reshape(-1, groupsize).max(axis=1)`. #Wq are `s4 / u4` stored as dtype = torch.int8 / torch.uint8, shape N x K -# scales and zeros are shape (N * K // group_size) +# scales and zeros are shape (N * K // groupsize) Wq, scales, zeros = quantize(weights) #Choose your favorite quantization library #Pack i4 stored as i8 to packed 2xi4 i8. 
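#Illustrative continuation (an assumption, not from this patch): one way to produce the shapes
#the kernel call below expects. The actual bit layout of torchao's pack_2xint4 may differ, so
#treat the packing line as a sketch only.
groupsize = 128
N, K = weights.shape
#groupwise qparams along axis=1: one (scale, zero) per `groupsize` weights
scales = scales.reshape(N, K // groupsize)  #later passed to the kernel as scales.T
zeros = zeros.reshape(N, K // groupsize)    #later passed as zeros.T
#pack two 4-bit codes into one uint8 along K, giving a (K // 2) x N operand
packed_w = (Wq.T[::2] & 0xF) | (Wq.T[1::2] << 4)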
@@ -54,7 +54,7 @@ tt_out = triton_mixed_mm( scales.T, zeros.T, transposed=False, - group_size=group_size, + groupsize=groupsize, fp8_fast_accum=False, ) ``` @@ -72,7 +72,7 @@ tt_out = triton_mixed_mm( Initial benchmarking (on `A6000`) demonstrates promising results, scaling well for compute-bound workloads: -| | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm | +| | M | N | K | groupsize | dtype | hqq_ref | triton | tinygemm | | --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- | | 0 | 16 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2675 | 0.0633 | 0.0382 | | 1 | 32 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2669 | 0.0704 | 0.0649 | diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index 0c4ae45c61..a14447da4a 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -37,7 +37,7 @@ def __init__( self.del_orig = del_orig weight_quant_params = self.quant_config["weight_quant_params"] - self.groupsize = weight_quant_params["group_size"] + self.groupsize = weight_quant_params["groupsize"] self.nbits = weight_quant_params["nbits"] self.inner_k_tiles = inner_k_tiles self.padding = padding @@ -120,10 +120,10 @@ def quantize( # TODO: move these to utils @torch.no_grad() - def reshape_meta_axis1(self, meta_tensor, new_group_size, shape): + def reshape_meta_axis1(self, meta_tensor, new_groupsize, shape): meta_tensor = meta_tensor.repeat([1, shape[1]]).reshape(shape) meta_tensor = torch.mean( - meta_tensor.reshape([-1, new_group_size]), axis=1, keepdim=True + meta_tensor.reshape([-1, new_groupsize]), axis=1, keepdim=True ) return meta_tensor diff --git a/torchao/prototype/hqq/kernels.py b/torchao/prototype/hqq/kernels.py index 8409fcb68b..92321017d3 100644 --- a/torchao/prototype/hqq/kernels.py +++ b/torchao/prototype/hqq/kernels.py @@ -239,9 +239,9 @@ def _mixed_mm_kernel( width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // group_size + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // groupsize rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M if not DEBUG: @@ -394,4 +394,4 @@ def _mixed_mm_kernel( mixed_mm_kernel_compute_bound = triton.autotune( configs=get_configs_compute_bound(), key=["M", "N", "K"] )(_mixed_mm) -_mixed_mm_debug = _mixed_mm \ No newline at end of file +_mixed_mm_debug = _mixed_mm diff --git a/torchao/prototype/hqq/mixed_mm.py b/torchao/prototype/hqq/mixed_mm.py index 6a933fa98c..9d962989f7 100644 --- a/torchao/prototype/hqq/mixed_mm.py +++ b/torchao/prototype/hqq/mixed_mm.py @@ -38,7 +38,7 @@ def triton_mixed_mm( b, scales, zeros, - group_size, + groupsize, transposed=False, acc_dtype=None, input_precision="ieee", @@ -54,9 +54,9 @@ def triton_mixed_mm( Args: a (torch.Tensor): M x K if not transposed, M x N if transposed b (torch.Tensor): (K // 2) x N, packed such that 2 int4's are packed into 1 uint8 (see pack_2xint4) - scales (torch.Tensor): (num_groups x N), where num_groups = (N * K / group_size) + scales (torch.Tensor): (num_groups x N), where num_groups = (N * K / groupsize) zeros (torch.Tensor): same shape as scales - group_size (torch.Tensor): size of group in groupwise quantization -- MUST be along axis 1 of an N x K matrix + groupsize (torch.Tensor): size of group in groupwise quantization -- MUST be along axis 1 
of an N x K matrix transposed (bool, optional): Whether to run a transposed matmul where shapes are (M x N) x (K x N) => (M x K) acc_dtype (_type_, optional): dtype of accumulator. Defaults to None, which corresponds to tl.float32. input_precision (str, optional): Only relevant when dtype of a is torch.float32. Defaults to "ieee". @@ -88,7 +88,7 @@ def triton_mixed_mm( M, K = a.shape N = b.shape[1] if not transposed else b.shape[0] * 2 # assert scales.shape[1] == N if not transposed else scales.shape[0] == N - # assert scales.shape[0] == K // group_size if not transposed else scales.shape[1] == K // group_size + # assert scales.shape[0] == K // groupsize if not transposed else scales.shape[1] == K // groupsize assert scales.dtype == a.dtype assert scales.shape == zeros.shape assert zeros.dtype == a.dtype @@ -132,7 +132,7 @@ def triton_mixed_mm( scales.stride(1), TRANSPOSED=transposed, IS_BFLOAT16=a.dtype == torch.bfloat16, - QGROUP_SIZE=group_size, + QGROUP_SIZE=groupsize, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum, @@ -164,11 +164,11 @@ def triton_mixed_mm( EVEN_K=True, TRANSPOSED=transposed, IS_BFLOAT16=a.dtype == torch.bfloat16, - QGROUP_SIZE=group_size, + QGROUP_SIZE=groupsize, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum, DEBUG=True, ) - return c \ No newline at end of file + return c diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index a6e95d0bed..653b628930 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -104,8 +104,8 @@ example_inputs = m.example_inputs(dtype=dtype, device="cuda") m_bf16 = torch.compile(m_bf16, mode='max-autotune') # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) -group_size = 32 -m = quantize(m, int4_weight_only(group_size=group_size)) +groupsize = 32 +m = quantize(m, int4_weight_only(groupsize=groupsize)) torch._inductor.config.force_fuse_int_mm_with_mul = True torch._inductor.config.use_mixed_mm = True diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py index 71b585b15e..6a7bceea48 100644 --- a/torchao/quantization/prototype/qat.py +++ b/torchao/quantization/prototype/qat.py @@ -227,7 +227,7 @@ def backward(ctx, gy): # TODO: move this to core quantized_decomposed_lib.define( "fake_quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, " - "int quant_min, int quant_max, int group_size) -> Tensor" + "int quant_min, int quant_max, int groupsize) -> Tensor" ) @impl(quantized_decomposed_lib, "fake_quantize_per_channel_group", "CompositeImplicitAutograd") @@ -237,12 +237,12 @@ def fake_quantize_per_channel_group( zero_points: torch.Tensor, quant_min: int, quant_max: int, - group_size: int, + groupsize: int, ) -> torch.Tensor: - assert group_size > 1 - assert input.shape[-1] % group_size == 0 + assert groupsize > 1 + assert input.shape[-1] % groupsize == 0 assert input.dim() == 2 - grouped_input = input.reshape(-1, group_size).to(torch.float32) + grouped_input = input.reshape(-1, groupsize).to(torch.float32) scales = scales.reshape(-1, 1) zero_points = zero_points.reshape(-1, 1) fq = _GenericFakeQuantize.apply( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3a1516d9b5..d9938b2a35 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -305,13 +305,13 @@ def filter_fn(module, fqn): ) return model -def 
int8_dynamic_activation_int4_weight(group_size=32): +def int8_dynamic_activation_int4_weight(groupsize=32): """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear This is used to produce a model for executorch backend, but currently executorch did not support lowering for the quantized model from this flow yet Args: - `group_size`: parameter for quantization, controls the granularity of quantization, smaller + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained """ def apply_int8_dynamic_activation_int4_weight_quant(weight): @@ -320,7 +320,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight): # weight settings mapping_type = MappingType.SYMMETRIC - block_size = (1, group_size) + block_size = (1, groupsize) target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps quant_min = -8 @@ -347,13 +347,13 @@ def get_per_token_block_size(x): return apply_int8_dynamic_activation_int4_weight_quant -def int4_weight_only(group_size=128, inner_k_tiles=8): +def int4_weight_only(groupsize=128, inner_k_tiles=8): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel Args: - `group_size`: parameter for quantization, controls the granularity of quantization, smaller + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32] `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] """ @@ -362,7 +362,7 @@ def apply_int4_weight_only_quant(weight): from torchao.dtypes import to_affine_quantized mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) + block_size = (1, groupsize) target_dtype = torch.int32 quant_min = 0 quant_max = 15 diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3e3943c93c..94015e9584 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -417,10 +417,10 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float def group_quantize_tensor_symmetric( w, n_bit=4, - group_size=128, + groupsize=128, precision=torch.float32, ): - scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision) + scales, zeros = get_group_qparams_symmetric(w, n_bit, groupsize, precision) n_bit = 4 max_int = 2 ** (n_bit - 1) - 1 min_int = -(2 ** (n_bit - 1)) @@ -428,7 +428,7 @@ def group_quantize_tensor_symmetric( # add torch.int4 to core later from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper w_int8 = _quantized_decomposed_quantize_per_channel_group_wrapper( - w, scales, zeros, min_int, max_int, torch.int8, group_size + w, scales, zeros, min_int, max_int, torch.int8, groupsize ) return w_int8, scales, zeros diff --git a/tutorials/quantize_vit/bfloat16_code.py b/tutorials/quantize_vit/bfloat16_code.py index 2e0ec26206..433afb763c 100644 --- a/tutorials/quantize_vit/bfloat16_code.py +++ b/tutorials/quantize_vit/bfloat16_code.py @@ -693,9 +693,9 @@ def triton_(in_ptr0, arg_A, arg_B, out_ptr0): # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % 
groupsize) + pid_n = (pid % width) // (groupsize) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/tutorials/quantize_vit/quant_code.py b/tutorials/quantize_vit/quant_code.py index e5fa7356e4..3f438f4a3b 100644 --- a/tutorials/quantize_vit/quant_code.py +++ b/tutorials/quantize_vit/quant_code.py @@ -458,9 +458,9 @@ def triton_(arg_A, arg_B, in_ptr2, out_ptr0): # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -669,9 +669,9 @@ def triton_(arg_A, arg_B, in_ptr2, out_ptr0): # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -862,9 +862,9 @@ def triton_(arg_A, arg_B, in_ptr2, out_ptr0): # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -1226,9 +1226,9 @@ def triton_(arg_A, arg_B, in_ptr2, in_ptr3, in_ptr4, out_ptr1): # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) + groupsize = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % groupsize) + pid_n = (pid % width) // (groupsize) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
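# Standalone sketch of the "grouped" program-id swizzle that appears throughout the Triton
# kernels in this patch. Note that here `groupsize` is the number of M-tiles grouped together
# for L2 reuse, not the quantization group size. Pure Python, runnable as-is.
def swizzle(pid, grid_m, grid_n, GROUP_M):
    width = GROUP_M * grid_n
    group_id = pid // width
    groupsize = min(grid_m - group_id * GROUP_M, GROUP_M)  # last group may be shorter
    pid_m = group_id * GROUP_M + (pid % groupsize)
    pid_n = (pid % width) // groupsize
    return pid_m, pid_n

# With grid_m=4, grid_n=4, GROUP_M=2, consecutive pids sweep a 2-tall band of tiles across the
# columns before moving down, so nearby programs reuse the same A-rows / B-columns in L2:
print([swizzle(pid, 4, 4, 2) for pid in range(8)])
# -> [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]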