Commit 1ec9b07
testing kv_cache quantization [WIP]
Summary: The peak memory improvement is extremely small; I tried a few things to fix this but didn't have any luck. Accuracy is also very poor (the generated text is unintelligible). I tried leaving the most recent token unquantized (since we have full-fidelity information for the current token), but that didn't solve the issue and resulted in a significant memory increase. Affine quantization may be worth trying next, though the lack of a memory improvement is currently the bigger concern. (See benchmark_results.txt for the results; compare the kv_quant: True and kv_quant: False rows.)

I also took a memory trace, which you can download (if you're a Meta employee) with:

jf download GCqU9BqGNUybzv8CABWUzUtOiPZ5bsIXAAAz --file "mem_trace_kvq.html"

Test Plan: sh benchmarks.sh

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 3804d74 commit 1ec9b07
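
For context on why the peak-memory delta is so small at max_new_tokens=200, here is a rough back-of-envelope sketch (not part of the commit; the cache shape comes from the pdb note in model.py below and the layer count is the standard Llama-2-7B config): the entire bf16 KV cache at ~208 positions is on the order of 0.1 GB, so even halving it barely registers against a 13.21 GB model.

# Hypothetical estimate, not from the commit: KV cache footprint for
# Llama-2-7b-chat-hf at the cache shape noted in model.py, (1, 32, 208, 128).
n_layers, n_heads, head_dim = 32, 32, 128   # standard Llama-2-7B config
batch, max_seq_length = 1, 208              # min(T_new, block_size) for 200 new tokens

elems = batch * n_heads * max_seq_length * head_dim    # elements per k or v cache, per layer
bf16_bytes = 2 * elems * 2 * n_layers                  # 2 bytes/elem, k+v caches, all layers
scale_elems = batch * n_heads * max_seq_length         # one bf16 scale per cached token
int8_bytes = (elems + 2 * scale_elems) * 2 * n_layers  # 1 byte/elem plus scales, k+v, all layers

print(f"bf16 KV cache: {bf16_bytes / 1e9:.3f} GB")     # ~0.109 GB
print(f"int8 KV cache: {int8_bytes / 1e9:.3f} GB")     # ~0.055 GB

Roughly 0.05 GB of savings against a ~13.9 GB peak is consistent with the "extremely small" improvement described above; the cache only becomes a meaningful fraction of memory at much longer sequence lengths.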

File tree

4 files changed (+90, -27 lines)


torchao/_models/llama/benchmark_results.txt

Lines changed: 12 additions & 0 deletions
@@ -17,3 +17,15 @@
 20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+# done with quantization of latest token
+20240718131341, tok/s=108.87, mem/s=1438.62 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240718131549, tok/s=103.15, mem/s=1363.06 GB/s, peak_mem=13.86 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240718131820, tok/s=163.84, mem/s=1084.89 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240718132103, tok/s=154.76, mem/s=1024.78 GB/s, peak_mem= 8.93 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+# done with full accuracy for latest token
+20240718150644, tok/s=109.23, mem/s=1443.43 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240718151152, tok/s=100.29, mem/s=1325.29 GB/s, peak_mem=14.14 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240718151349, tok/s=166.08, mem/s=1099.70 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240718152147, tok/s=140.85, mem/s= 932.66 GB/s, peak_mem= 9.21 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
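
The summary points at the kv_quant: True vs kv_quant: False rows above for comparison; a small throwaway script along these lines (hypothetical, not part of the commit) pulls the relevant numbers back out of benchmark_results.txt:

# Illustrative helper, not part of the commit: print peak_mem for the rows that
# carry a kv_quant field so the True/False pairs can be compared side by side.
import re

pattern = re.compile(r"peak_mem=\s*([\d.]+) GB.*?quant: (\S+), mod: (\S+), kv_quant: (\w+)")
with open("benchmark_results.txt") as f:
    for line in f:
        m = pattern.search(line)
        if m:
            peak_mem, quant, mod, kv_quant = m.groups()
            print(f"{mod} quant={quant} kv_quant={kv_quant}: peak_mem={peak_mem} GB")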

torchao/_models/llama/benchmarks.sh

Lines changed: 7 additions & 1 deletion
@@ -6,7 +6,7 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
 # in readme
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
@@ -22,3 +22,9 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
+
+#####
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --kv_cache_quantization

torchao/_models/llama/generate.py

Lines changed: 27 additions & 9 deletions
@@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
 next_token, next_prob = decode_one_token(
 model, cur_token, input_pos, **sampling_kwargs
 )
+next_token, next_prob = next_token.clone(), next_prob.clone()
 input_pos += 1
-new_tokens.append(next_token.clone())
+new_tokens.append(next_token)
 callback(new_tokens[-1])
-new_probs.append(next_prob.clone())
+new_probs.append(next_prob)
 cur_token = next_token.view(1, -1)
 
 return new_tokens, new_probs
@@ -88,23 +89,32 @@ def generate(
 *,
 interactive: bool,
 callback = lambda x: x,
+kv_cache_quantization: bool = False,
 **sampling_kwargs
 ) -> torch.Tensor:
 """
 Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
 """
-
 # create an empty tensor of the expected final shape and fill in the current tokens
 device = prompt.device
 T = prompt.numel()
 T_new = T + max_new_tokens
 seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
 seq[:T] = prompt.view(-1)
-
 # setup model cache
 max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
 with torch.device(device):
 model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+if kv_cache_quantization:
+from model import QuantizedKVCache
+# go through the model and do the swaps
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+_replace_with_custom_fn_if_matches_filter(
+model,
+QuantizedKVCache.from_float,
+lambda x, y: isinstance(x, torchao._models.llama.model.KVCache),
+)
+
 
 # format model input
 x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
@@ -147,6 +157,7 @@ def main(
 temperature: float = 0.8,
 checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
 quantization: Optional[str] = None,
+kv_cache_quantization: bool = False,
 compile: bool = True,
 compile_prefill: bool = False,
 profile: Optional[Path] = None,
@@ -157,6 +168,7 @@
 """Generates text samples based on a pre-trained Transformer model and tokenizer.
 """
 
+# torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=1000000, trace_alloc_record_context=True)
 torchao.quantization.utils.recommended_inductor_config_setter()
 
 assert checkpoint_path.is_file(), checkpoint_path
@@ -179,9 +191,7 @@ def main(
 encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
 prompt_length = encoded.size(0)
 
-torch.manual_seed(1234)
-
-
+torch.manual_seed(1234)
 if quantization:
 from torchao.quantization.quant_api import (
 quantize_,
@@ -276,7 +286,14 @@ def callback(x):
 callback=callback,
 temperature=temperature,
 top_k=top_k,
+kv_cache_quantization=kv_cache_quantization,
 )
+# if i==3:
+# snapshot = torch.cuda.memory._snapshot()
+# from pickle import dump
+# with open("mem_trace_kvq_no_comp" + '.pickle', 'wb') as f:
+# dump(snapshot, f)
+# breakpoint()
 if i == -1:
 print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 continue
@@ -305,12 +322,13 @@ def callback(x):
 print(f"Model Size: {model_size:.02f} GB")
 if write_result:
 result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
-result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
+result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
 result_txt += f"repro: python generate.py "
 result_txt += f"--quantization {quantization} " if quantization else ""
 result_txt += f"--checkpoint_path {checkpoint_path} "
 result_txt += f"--device {device} "
 result_txt += f"--precision {precision} "
+result_txt += f"--kv_cache_quantization " if kv_cache_quantization else ""
 result_txt += f"--compile " if compile else ""
 result_txt += f"--compile_prefill " if compile_prefill else ""
 result_txt += f"--profile {profile} " if profile else ""
@@ -348,5 +366,5 @@ def callback(x):
 args = parser.parse_args()
 main(
 args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
+args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
 )

torchao/_models/llama/model.py

Lines changed: 44 additions & 17 deletions
@@ -85,7 +85,6 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torc
 def update(self, input_pos, k_val, v_val):
 # input_pos: [S], k_val: [B, H, S, D]
 assert input_pos.shape[0] == k_val.shape[2]
-
 if use_index_put_for_kv_cache:
 k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
 v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
@@ -97,23 +96,51 @@ def update(self, input_pos, k_val, v_val):
 
 return k_out, v_out
 
-# class QuantizedKVCache(nn.Module):
-# def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
-# super().__init__()
-# cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-# self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.uint8))
-# self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.uint8))
-# self.register_buffer('k_cache_scale', torch.ones(cache_shape, dtype=torch.bfloat16))
-# self.register_buffer('v_cache_scale', torch.ones(cache_shape, dtype=torch.bfloat16))
+
+# (Pdb) p k_val.shape
+# torch.Size([1, 32, 6, 128])
+# (Pdb) p self.k_cache.shape
+# torch.Size([1, 32, 208, 128]) so want final size to be 1,32,208,[1]
+
+from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine
+from torchao.quantization.utils import quantize_activation_per_token_absmax
+
+class QuantizedKVCache(nn.Module):
+def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16):
+super().__init__()
+cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
+self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8))
+self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8))
+self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
+self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
 
-# def update(self, input_pos, k_val, v_val):
-# k_out = self.k_cache
-# v_out = self.v_cache
-# k_out[:, :, input_pos] = k_val
-# v_out[:, :, input_pos] = v_val
-
-# @classmethod
-# def from_kv_cache(cls, kv_cache):
+def update(self, input_pos, k_val, v_val):
+# k_out = self.k_cache*self.k_cache_scale
+# v_out = self.v_cache*self.v_cache_scale
+# k_out[:, :, input_pos] = k_val
+# v_out[:, :, input_pos] = v_val
+
+q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
+self.k_cache[:, :, input_pos] = q_k_val
+self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
+del k_val
+
+q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
+self.v_cache[:, :, input_pos] = q_v_val
+self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
+del v_val
+
+# return k_out, v_out
+return self.k_cache*self.k_cache_scale, self.v_cache*self.v_cache_scale
+
+@classmethod
+def from_float(cls, kv_cache):
+cache_shape = kv_cache.k_cache.shape
+max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+scale_dtype = kv_cache.k_cache.dtype
+return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype)
+
 
 
 class Transformer(nn.Module):
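
To check the cache quantization in isolation from generation quality, one option is to round-trip a dummy K tensor through the same per-token absmax helper that QuantizedKVCache.update uses. The snippet below is only a hedged sketch of that check (the tensor shape mirrors the pdb note above; it is not part of the commit):

# Hedged sanity-check sketch, not part of the commit: round-trip a fake K slice
# through quantize_activation_per_token_absmax and dequantize it the same way
# QuantizedKVCache.update does (int8 values times a per-token bf16 scale).
import torch
from torchao.quantization.utils import quantize_activation_per_token_absmax

k_val = torch.randn(1, 32, 6, 128, dtype=torch.bfloat16)  # shape from the pdb note above
q_k, k_scale = quantize_activation_per_token_absmax(k_val)
k_hat = q_k * k_scale.unsqueeze(-1)                       # same dequant as update()
err = (k_hat.float() - k_val.float()).abs().max()
print(f"max abs reconstruction error: {err.item():.4f}")

If the round-trip error here is small, accuracy problems during generation are more likely to come from how the cache is written and read back than from the per-token quantization itself.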
