@@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
+            next_token, next_prob = next_token.clone(), next_prob.clone()
             input_pos += 1
-            new_tokens.append(next_token.clone())
+            new_tokens.append(next_token)
             callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
+            new_probs.append(next_prob)
             cur_token = next_token.view(1, -1)
 
     return new_tokens, new_probs
@@ -88,23 +89,32 @@ def generate(
     *,
     interactive: bool,
     callback = lambda x: x,
+    kv_cache_quantization: bool = False,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
-
     # create an empty tensor of the expected final shape and fill in the current tokens
     device = prompt.device
     T = prompt.numel()
     T_new = T + max_new_tokens
     seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
     seq[:T] = prompt.view(-1)
-
     # setup model cache
     max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        if kv_cache_quantization:
+            from model import QuantizedKVCache
+            # go through the model and do the swaps
+            from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+            _replace_with_custom_fn_if_matches_filter(
+                model,
+                QuantizedKVCache.from_float,
+                lambda x, y: isinstance(x, torchao._models.llama.model.KVCache),
+            )
+
 
     # format model input
     x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
@@ -147,6 +157,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    kv_cache_quantization: bool = False,
     compile: bool = True,
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
@@ -157,6 +168,7 @@ def main(
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
 
+    # torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=1000000, trace_alloc_record_context=True)
     torchao.quantization.utils.recommended_inductor_config_setter()
 
     assert checkpoint_path.is_file(), checkpoint_path
@@ -179,9 +191,7 @@ def main(
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
     prompt_length = encoded.size(0)
 
-    torch.manual_seed(1234)
-
-
+    torch.manual_seed(1234)
     if quantization:
         from torchao.quantization.quant_api import (
             quantize_,
@@ -276,7 +286,14 @@ def callback(x):
             callback=callback,
             temperature=temperature,
             top_k=top_k,
+            kv_cache_quantization=kv_cache_quantization,
         )
+        # if i==3:
+        #     snapshot = torch.cuda.memory._snapshot()
+        #     from pickle import dump
+        #     with open("mem_trace_kvq_no_comp" + '.pickle', 'wb') as f:
+        #         dump(snapshot, f)
+        #     breakpoint()
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
@@ -305,12 +322,13 @@ def callback(x):
     print(f"Model Size: {model_size:.02f} GB")
     if write_result:
         result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
-        result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
+        result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
         result_txt += f"repro: python generate.py "
         result_txt += f"--quantization {quantization} " if quantization else ""
         result_txt += f"--checkpoint_path {checkpoint_path} "
         result_txt += f"--device {device} "
         result_txt += f"--precision {precision} "
+        result_txt += f"--kv_cache_quantization " if kv_cache_quantization else ""
         result_txt += f"--compile " if compile else ""
         result_txt += f"--compile_prefill " if compile_prefill else ""
         result_txt += f"--profile {profile} " if profile else ""
@@ -348,5 +366,5 @@ def callback(x):
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
     )
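
For readers unfamiliar with the helper used in the cache-swap hunk: `_replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn)` walks the module tree and replaces every submodule accepted by `filter_fn` with the result of `replacement_fn(submodule)`, so passing `QuantizedKVCache.from_float` converts each float `KVCache` in place. As a rough illustration only, below is a minimal sketch of what such a quantized cache could look like, assuming symmetric per-token int8 storage, a `(batch, heads, seq_len, head_dim)` cache layout matching the float cache, and `input_pos` being a 1-D tensor of positions; the class and buffer names here are hypothetical and are not the PR's actual `model.py` implementation.

```python
import torch
import torch.nn as nn


class QuantizedKVCacheSketch(nn.Module):
    """Hypothetical int8 KV cache: quantized keys/values plus per-token scales."""

    def __init__(self, max_batch_size, n_heads, max_seq_length, head_dim, scale_dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=torch.int8))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=torch.int8))
        self.register_buffer("k_scale", torch.ones(scale_shape, dtype=scale_dtype))
        self.register_buffer("v_scale", torch.ones(scale_shape, dtype=scale_dtype))

    @staticmethod
    def _quantize(x):
        # symmetric int8 quantization along head_dim: one scale per (batch, head, token)
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
        q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
        return q, scale

    def update(self, input_pos, k_val, v_val):
        # quantize the new positions, store them, then return dequantized full caches
        q_k, k_scale = self._quantize(k_val)
        q_v, v_scale = self._quantize(v_val)
        self.k_cache[:, :, input_pos] = q_k
        self.v_cache[:, :, input_pos] = q_v
        self.k_scale[:, :, input_pos] = k_scale.to(self.k_scale.dtype)
        self.v_scale[:, :, input_pos] = v_scale.to(self.v_scale.dtype)
        k_out = self.k_cache.to(k_val.dtype) * self.k_scale.to(k_val.dtype)
        v_out = self.v_cache.to(v_val.dtype) * self.v_scale.to(v_val.dtype)
        return k_out, v_out

    @classmethod
    def from_float(cls, kv_cache):
        # infer dimensions from the existing float cache's k_cache buffer
        max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
        return cls(max_batch_size, n_heads, max_seq_length, head_dim,
                   scale_dtype=kv_cache.k_cache.dtype)
```

The memory saving in such a scheme comes from holding the caches at 1 byte per element plus small scale tensors instead of 2 bytes per element for bf16, at the cost of a quantize/dequantize step on every decode iteration.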