Commit c38fc4a

Bitpackingv2 (#307)
* untested unified pack/unpack
* tests written, issues fixed
* removed conversion
* works with compile + use pytest params
* added hqq int4 fp16 mixed matmul benchmark for pack
* added more repeats for benchmark and removed unused vars
* added 1 more benchmark and now tests pass
* added order to unpack and updated tests
* removed main code and added a text example
* added example
* organized benchmarks
1 parent 6a380a3 commit c38fc4a
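
The "unified pack/unpack" mentioned in the commit message lives in torchao.prototype.common.bitpacking and is exercised by the benchmark below with an element bit width and a dim argument. A minimal round-trip sketch of that usage, assuming (this page does not show it) that unpack inverts pack for values that fit in the target bit width:

    import torch
    from torchao.prototype.common.bitpacking import pack, unpack

    # 4-bit values stored in a uint8 container, packed two per byte along dim=1
    x = torch.randint(0, 2**4, (1, 1024, 1024), dtype=torch.uint8).cuda()
    packed = pack(x, 4, dim=1)            # the packed dim should shrink by 8/4 = 2x
    unpacked = unpack(packed, 4, dim=1)
    assert torch.equal(unpacked, x)       # assumed round-trip property; not asserted anywhere on this page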

File tree

3 files changed: +450 / -129 lines


benchmarks/benchmark_bitpacking.py: +227 lines
@@ -0,0 +1,227 @@
import torch

from torchao.prototype.common.bitpacking import pack, unpack
from torchao.dtypes.uint4 import unpack_uint4, pack_uint4


def benchmark(function, num_runs, *args, **kwargs):
    """Return the average CUDA time (ms) of function(*args, **kwargs) over num_runs runs."""
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    for _ in range(num_runs):
        function(*args, **kwargs)

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs


def test_vs_existing():
    # compare the unified pack/unpack against the existing uint4 pack/unpack
    def new_():
        fake_tensor = torch.randint(0, 2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
        packed = pack(fake_tensor, 4, dim=1)
        unpacked = unpack(packed, 4, dim=1)

    def old_():
        fake_tensor = torch.randint(0, 2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
        packed = pack_uint4(fake_tensor)
        unpacked = unpack_uint4(packed)

    new_ = torch.compile(new_, fullgraph=True)
    old_ = torch.compile(old_, fullgraph=True)
    # run once to compile
    new_()
    old_()
    print(f"new: {benchmark(new_, 1000)} ms")
    print(f"old: {benchmark(old_, 1000)} ms")


def test_iso_bitpack():
    # time loading a full-size uint8 tensor vs. loading a smaller uint8 tensor
    # and unpacking it to uint2 / uint4 on the fly
    def load4x(scale=1024):
        fake_tensor = torch.randint(0, 2**8, (1, 4 * scale, scale), dtype=torch.uint8).cuda()

    def load2x(scale=1024):
        fake_tensor = torch.randint(0, 2**8, (1, 2 * scale, scale), dtype=torch.uint8).cuda()

    def loadx(scale=1024):
        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()

    def unpack8to2(scale=1024):
        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
        unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)

    def unpack8to4(scale=1024):
        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)

    def t8to4wmm(scale=1024):
        fake_tensor = torch.randint(0, 2**8, (8, 1024, 1024), dtype=torch.uint8).cuda()
        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)

    torch._dynamo.config.specialize_int = True
    unpack_c = torch.compile(unpack, fullgraph=True)

    scale = [16, 64, 256, 1024, 4096]
    load4x_times = []
    unpack8to2_times = []
    load2x_times = []
    unpack8to4_times = []
    for s in scale:
        res = benchmark(load4x, 50, scale=s)
        load4x_times.append(res)
        print(f"load(1, {4*s},{s}) time: {res} ms")

        res = benchmark(unpack8to2, 50, scale=s)
        unpack8to2_times.append(res)
        print(f"load(1, {s},{s}) unpack uint2 time: {res} ms")

        res = benchmark(load2x, 50, scale=s)
        load2x_times.append(res)
        print(f"load(1, {2*s},{s}) time: {res} ms")

        res = benchmark(unpack8to4, 50, scale=s)
        unpack8to4_times.append(res)
        print(f"load(1, {s},{s}) unpack uint4 time: {res} ms")
        print()

    # import matplotlib.pyplot as plt
    # plt.plot(scale, load4x_times, label="load(1, 4x, x)")
    # plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2")
    # plt.plot(scale, load2x_times, label="load(1, 2x, x)")
    # plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4")
    # plt.xlabel("scale")
    # plt.ylabel("time (ms)")
    # plt.yscale("log")
    # plt.legend()
    # plt.savefig("benchmark_bitpacking.png")


def test_vs_hqqpack():
    # requires hqq to be installed
    import hqq
    import hqq.core.quantize as hqq_quantize
    HQQLinear = hqq_quantize.HQQLinear
    BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
    from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

    BASE_QUANT_CONFIG = {
        "optimize": True,
        "view_as_float": False,
        "nbits": 4,
        "bitpack": False,
        "axis": 1,
    }

    def mixed_mm(
        shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn=True
    ):
        qcfg = {
            **BASE_QUANT_CONFIG,
            **dict(group_size=group_size, axis=axis),
        }
        M, N, K = shape

        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

        quant_config = BaseQuantizeConfig(
            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
        )
        quant_config.update({"weight_quant_params": qcfg})
        hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
        W_q, meta = hqq_linear.W_q, hqq_linear.meta
        W_q = W_q.to(dtype=quant_dtype)
        W_q = (
            W_q.reshape(meta["shape"])
            if quant_config["weight_quant_params"]["bitpack"] == False
            else W_q
        )
        W_dq = hqq_linear.dequantize()

        scales, zeros = meta["scale"], meta["zero"]
        scales = scales.reshape(N, -1)
        zeros = zeros.reshape(N, -1)
        # pack the 4-bit weights either with the unified pack() or with hqq's pack_2xint4
        if pack_fn:
            packed_w = pack_c(W_q.T, 4, dim=0, order=False)
        else:
            packed_w = pack_2xint4(W_q.T)

        if transposed:
            x = torch.randn(M, N, dtype=dtype, device="cuda")
            hqq_out = x @ W_dq

            tt_out = triton_mixed_mm(
                x,
                packed_w,
                scales.T,
                zeros.T,
                transposed=True,
                group_size=group_size,
                fp8_fast_accum=False,
                kernel_type=kernel_type,
            )

        else:
            x = torch.randn(M, K, dtype=dtype, device="cuda")
            hqq_out = x @ W_dq.T

            tt_out = triton_mixed_mm(
                x,
                packed_w,
                scales.T,
                zeros.T,
                transposed=False,
                group_size=group_size,
                fp8_fast_accum=False,
                kernel_type=kernel_type,
            )

    shapes = [
        [16, 128, 128],
        [16, 4096, 4096],
    ]
    group_sizes = [64, 128]
    # compile the unified pack under a new name so the closure in mixed_mm picks it up
    pack_c = torch.compile(pack, fullgraph=True)
    for shape, group_size in zip(shapes, group_sizes):
        print("linear layer size: ", shape)
        print("group size: ", group_size)
        # run once to compile
        mixed_mm(
            shape,
            group_size,
            1,
            torch.float16,
            True,
            "compute_bound",
            torch.uint8,
        )
        # shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
        print("pack time (ms): ", benchmark(mixed_mm, 100,
            shape,
            group_size,
            1,
            torch.float16,
            True,
            "compute_bound",
            torch.uint8))

        print("pack_2xint4 time (ms): ", benchmark(mixed_mm, 100,
            shape,
            group_size,
            1,
            torch.float16,
            True,
            "compute_bound",  # max autotune doesn't work?
            torch.uint8,
            pack_fn=False))
        print("")


if __name__ == "__main__":
    test_vs_existing()
