groupsize consistency #417
Closed
wants to merge 1 commit into from
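For context on the rename below: in these files the group size is the number of consecutive weight values along the K dimension that share one quantization scale/zero pair, so an (N, K) weight carries scales and zeros of shape (N, K // groupsize). This PR only changes the keyword spelling from `group_size` to `groupsize`; the semantics are unchanged. Below is a minimal, self-contained sketch of the group-wise dequantization the touched tests exercise, mirroring the reference path in `test/hqq/test_triton_mm.py` (the shapes and the CPU tensors here are illustrative assumptions, not values from the PR):

```python
import torch

# Illustrative shapes (assumptions for this sketch, not values from the PR).
N, K, groupsize = 16, 64, 16
dtype = torch.float16

# 4-bit codes stored as uint8; one scale/zero pair per group of `groupsize` weights.
W_q = torch.randint(0, 2**4, size=(N, K), dtype=torch.uint8)
scales = torch.arange((N * K) // groupsize, dtype=dtype)[:, None]
zeros = torch.zeros_like(scales)

# Group-wise affine dequantization, as in the test's reference path:
# view the weight as (num_groups, groupsize), apply per-group scale/zero, reshape back.
W_dq = ((W_q.reshape(-1, groupsize) - zeros) * scales).reshape(N, K)
assert W_dq.shape == (N, K)

# The kernels and configs in the diff take this group size as a keyword argument;
# the PR renames that keyword from group_size to groupsize.
```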
22 changes: 11 additions & 11 deletions benchmarks/benchmark_hqq.py
@@ -31,7 +31,7 @@ def bench_custom_kernel(
     W_q,
     scales,
     zeros,
-    group_size,
+    groupsize,
     transposed=False,
     kernel_type="max_autotune",
     fp8_fast_accum=False,
@@ -45,7 +45,7 @@ def fn():
             scales.T,
             zeros.T,
             transposed=transposed,
-            group_size=group_size,
+            groupsize=groupsize,
             fp8_fast_accum=fp8_fast_accum,
             kernel_type=kernel_type,
         )
@@ -65,11 +65,11 @@ def reference_fn():
 
 
 def run_benchmark(
-    shape, group_size, dtype, axis=1, transposed=False, quant_dtype=torch.uint8
+    shape, groupsize, dtype, axis=1, transposed=False, quant_dtype=torch.uint8
 ):
     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=groupsize, axis=axis),
     }
     M, N, K = shape
 
@@ -103,7 +103,7 @@ def run_benchmark(
     scales = scales.reshape(N, -1)
     zeros = zeros.reshape(N, -1)
     tt_time = bench_custom_kernel(
-        x, W_q, scales, zeros, group_size, transposed=transposed
+        x, W_q, scales, zeros, groupsize, transposed=transposed
     )
 
     should_run_tinygemm = dtype == torch.bfloat16 and not transposed
@@ -114,7 +114,7 @@
         )
         int4_time = bench_hqq(x, hqq_int4mm, transposed=transposed, tinygemm=True)
 
-    print(f"{shape=}, {group_size=}, {dtype=}, {transposed=}:")
+    print(f"{shape=}, {groupsize=}, {dtype=}, {transposed=}:")
 
     print(
         f"Ref: {ref_time:.4f}ms",
@@ -146,7 +146,7 @@ def run_benchmark(
     "M",
     "N",
     "K",
-    "group_size",
+    "groupsize",
     "dtype",
     "transposed",
     "ref",
@@ -159,16 +159,16 @@
     print(torch.cuda.get_device_properties(0))
 
     for shape in SHAPES:
-        for group_size in GROUP_SIZES:
+        for groupsize in GROUP_SIZES:
             for dtype in DTYPES:
                 for transposed in TRANSPOSED:
                     timings = run_benchmark(
-                        shape, group_size, dtype, transposed=transposed
+                        shape, groupsize, dtype, transposed=transposed
                     )
-                    data.append((*shape, group_size, dtype, transposed, *timings))
+                    data.append((*shape, groupsize, dtype, transposed, *timings))
 
     output = StringIO()
     df = pd.DataFrame(data, columns=HEADERS)
     df.to_csv(output, index=False)
     print(output.getvalue())
-    # df.to_csv("benchmark_hqq_tinygemm.csv", index=False)
+    # df.to_csv("benchmark_hqq_tinygemm.csv", index=False)
2 changes: 1 addition & 1 deletion benchmarks/dora/bench_utils.py
@@ -111,7 +111,7 @@ def setup_dora_base_layers(layer_type, in_features, out_features, dtype):
     # HQQ
     quant_config = BaseQuantizeConfig(
         nbits=4,
-        group_size=64,
+        groupsize=64,
         quant_zero=False,
         quant_scale=False,
         offload_meta=True,
2 changes: 1 addition & 1 deletion test/dora/test_dora_layer.py
@@ -91,7 +91,7 @@ def test_dora_layer(
     elif model_type == "HQQDoRALinear":
         quant_config = BaseQuantizeConfig(
             nbits=4,
-            group_size=64,
+            groupsize=64,
             quant_zero=False,
             quant_scale=False,
             offload_meta=True,
2 changes: 1 addition & 1 deletion test/dtypes/test_affine_quantized.py
@@ -12,7 +12,7 @@ class TestAffineQuantized(TestCase):
     def test_tensor_core_layout_transpose(self):
         t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
         shape = t.shape
-        apply_int4_weight_only_quant = int4_weight_only(group_size=32)
+        apply_int4_weight_only_quant = int4_weight_only(groupsize=32)
         aqt = apply_int4_weight_only_quant(t)
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
30 changes: 15 additions & 15 deletions test/hqq/test_triton_mm.py
@@ -72,16 +72,16 @@ def _arg_to_id(arg):
 
 
 @pytest.mark.parametrize(
-    "shape, group_size, axis, dtype, transposed, kernel_type",
+    "shape, groupsize, axis, dtype, transposed, kernel_type",
     TEST_CONFIGS,
     ids=_arg_to_id,
 )
 def test_mixed_mm(
-    shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
+    shape, groupsize, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
 ):
     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=groupsize, axis=axis),
     }
     M, N, K = shape
 
@@ -117,7 +117,7 @@ def test_mixed_mm(
             scales.T,
             zeros.T,
             transposed=True,
-            group_size=group_size,
+            groupsize=groupsize,
             fp8_fast_accum=False,
             kernel_type=kernel_type,
         )
@@ -132,7 +132,7 @@ def test_mixed_mm(
             scales.T,
             zeros.T,
             transposed=False,
-            group_size=group_size,
+            groupsize=groupsize,
             fp8_fast_accum=False,
             kernel_type=kernel_type,
         )
@@ -147,7 +147,7 @@ def test_mixed_mm(
 # Only for debugging kernel without dependency on HQQ and with no autotuning
 def _test_mixed_mm(
     shape,
-    group_size,
+    groupsize,
     BLOCK_M,
     BLOCK_N,
     BLOCK_K,
@@ -159,7 +159,7 @@ def _test_mixed_mm(
 ):
     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=groupsize, axis=axis),
     }
     M, N, K = shape
 
@@ -169,9 +169,9 @@ def _test_mixed_mm(
     quant_config.update({"weight_quant_params": qcfg})
     W_q = torch.randint(0, int(2**4), size=(N, K), dtype=quant_dtype, device="cuda")
 
-    scales = torch.arange((N * K) // group_size, dtype=dtype, device="cuda")[:, None]
+    scales = torch.arange((N * K) // groupsize, dtype=dtype, device="cuda")[:, None]
     zeros = torch.zeros_like(scales)
-    W_dq = ((W_q.reshape(-1, group_size) - zeros) * scales).reshape(N, K)
+    W_dq = ((W_q.reshape(-1, groupsize) - zeros) * scales).reshape(N, K)
     scales = scales.reshape(N, -1)
     zeros = zeros.reshape(N, -1)
 
@@ -187,7 +187,7 @@ def _test_mixed_mm(
         scales.T,
         zeros.T,
         transposed=True,
-        group_size=group_size,
+        groupsize=groupsize,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
         BLOCK_M=BLOCK_M,
@@ -205,14 +205,14 @@ def _test_mixed_mm(
         scales.T,
         zeros.T,
         transposed=False,
-        group_size=group_size,
+        groupsize=groupsize,
        fp8_fast_accum=False,
         kernel_type=kernel_type,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         BLOCK_K=BLOCK_K,
     )
-    msg = f"shape={shape}, group_size={group_size}, axis={axis}, dtype={dtype}, transposed={transposed}, kernel_type={kernel_type}, quant_dtype={quant_dtype}"
+    msg = f"shape={shape}, groupsize={groupsize}, axis={axis}, dtype={dtype}, transposed={transposed}, kernel_type={kernel_type}, quant_dtype={quant_dtype}"
 
     check(
         hqq_out,
@@ -229,18 +229,18 @@ def _test_mixed_mm(
     BLOCK_M, BLOCK_N, BLOCK_K = shape
     BLOCK_K = K // 2
     BLOCK_N = N // 2
-    group_size = BLOCK_K
+    groupsize = BLOCK_K
     _test_mixed_mm(
         shape,
-        group_size=group_size,
+        groupsize=groupsize,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         BLOCK_K=BLOCK_K,
         transposed=False,
     )
     _test_mixed_mm(
         shape,
-        group_size=group_size,
+        groupsize=groupsize,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         BLOCK_K=BLOCK_K,
24 changes: 12 additions & 12 deletions test/hqq/test_triton_qkv_fused.py
@@ -75,42 +75,42 @@ def fuse_qkv(W_qs, scales, zeros):
     """
     Args:
         W_qs (list[torch.Tensor]): len 3 list of tensors with shapes Nq x K, Nk x K, Nv x K where Nk == Nv
-        scales (list[torch.Tensor]): each is N x (K // group_size), with same N requirements per W_qs
+        scales (list[torch.Tensor]): each is N x (K // groupsize), with same N requirements per W_qs
         zeros (list[torch.Tensor]): same as scales
 
     Returns:
         qkv (torch.Tensor): (N_qkv x K) where N_qkv = Nq + Nk + Nv
-        scales (torch.Tensor): (N_qkv x (K // group_size))
-        zeros (torch.Tensor): (N_qkv x (K // group_size))
+        scales (torch.Tensor): (N_qkv x (K // groupsize))
+        zeros (torch.Tensor): (N_qkv x (K // groupsize))
     """
     qkv = torch.cat(W_qs, dim=0)  # Fuse along N
     fused_scales = torch.cat([s for s in scales], dim=0)
     fused_zeros = torch.cat([z for z in zeros], dim=0)
     return qkv, fused_scales, fused_zeros
 
 
-def ref_proj(x, packed_w, scale, zero, group_size, kernel_type, transposed=False):
+def ref_proj(x, packed_w, scale, zero, groupsize, kernel_type, transposed=False):
     return triton_mixed_mm(
         x,
         packed_w,
         scale.T,
         zero.T,
         transposed=transposed,
-        group_size=group_size,
+        groupsize=group_size,
         fp8_fast_accum=False,
         kernel_type=kernel_type,
     )
 
 
 @pytest.mark.parametrize(
-    "q_shape, kv_shape, group_size, axis, dtype, transposed, kernel_type",
+    "q_shape, kv_shape, groupsize, axis, dtype, transposed, kernel_type",
     TEST_CONFIGS,
     ids=_arg_to_id,
 )
 def test_mixed_mm(
     q_shape,
     kv_shape,
-    group_size,
+    groupsize,
     axis,
     dtype,
     transposed,
@@ -136,7 +136,7 @@ def test_mixed_mm(
 
     qcfg = {
         **BASE_QUANT_CONFIG,
-        **dict(group_size=group_size, axis=axis),
+        **dict(groupsize=group_size, axis=axis),
     }
 
     quant_config = BaseQuantizeConfig(
@@ -172,7 +172,7 @@ def test_mixed_mm(
         xs = [torch.randn(seqlen, n, dtype=dtype, device=device) for n in Ns]
         x_fused = torch.cat(xs, dim=1)
         q_ref, k_ref, v_ref = [
-            ref_proj(x, p, s, z, group_size, kernel_type, transposed=True)
+            ref_proj(x, p, s, z, groupsize, kernel_type, transposed=True)
             for x, p, s, z in zip(xs, packed_ws, scales, zeros)
         ]
         tt_fused = triton_mixed_mm(
@@ -181,7 +181,7 @@ def test_mixed_mm(
             scales_fused.T,
             zeros_fused.T,
             transposed=True,
-            group_size=group_size,
+            groupsize=group_size,
             fp8_fast_accum=False,
             kernel_type=kernel_type,
         )
@@ -191,7 +191,7 @@ def test_mixed_mm(
         x = torch.randn(seqlen, K, dtype=dtype, device=device)
 
         q_ref, k_ref, v_ref = [
-            ref_proj(x, p, s, z, group_size, kernel_type)
+            ref_proj(x, p, s, z, groupsize, kernel_type)
             for p, s, z in zip(packed_ws, scales, zeros)
         ]
 
@@ -201,7 +201,7 @@ def test_mixed_mm(
             scales_fused.T,
             zeros_fused.T,
             transposed=False,
-            group_size=group_size,
+            groupsize=group_size,
             fp8_fast_accum=False,
             kernel_type=kernel_type,
         )
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
@@ -833,7 +833,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         def api(mod):
             if TORCH_VERSION_AFTER_2_4:
                 kwargs_copy = kwargs.copy()
-                kwargs_copy["group_size"] = groupsize
+                kwargs_copy["groupsize"] = groupsize
                 del kwargs_copy["groupsize"]
                 quantize(mod, int4_weight_only(**kwargs_copy))
                 unwrap_tensor_subclass(mod)
4 changes: 2 additions & 2 deletions test/quantization/test_galore_quant.py
@@ -42,7 +42,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape)
 
     tt_q, tt_norm, tt_absmax = triton_quantize_blockwise(
-        g, qmap, group_size=blocksize, return_normalized=True
+        g, qmap, groupsize=blocksize, return_normalized=True
     )
     tt_check = torch.allclose(ref_bnb, tt_q)
 
@@ -87,5 +87,5 @@ def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
 
     dq_ref = F.dequantize_blockwise(q, qstate)
-    dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize)
+    dq = triton_dequant_blockwise(q, qmap, qstate.absmax, groupsize=blocksize)
     assert torch.allclose(dq, dq_ref)