Skip to content

Commit 4a0c69c

Browse files
committed
testing kv_cache quantization
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 4f53882 commit 4a0c69c

File tree

3 files changed: 38 additions and 17 deletions.

torchao/_models/llama/benchmarks.sh

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -2,23 +2,23 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
22

33

44
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
5-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
6-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
7-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
5+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
6+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
7+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
88
# in readme
9-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
10-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
9+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
10+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
1111
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
12-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
13-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
12+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
13+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
1414

15-
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
16-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
17-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
18-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
19-
# in readme
20-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
21-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
22-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
23-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
24-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
15+
# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
16+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
17+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
18+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
19+
# # in readme
20+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
21+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
22+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
23+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
24+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

torchao/_models/llama/generate.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -337,13 +337,15 @@ def callback(x):
337337
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
338338
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
339339
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
340+
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
340341
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
341342
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
342343
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
343344
parser.add_argument('--device', type=str, default=default_device, help='Device to use')
344345
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
345346
parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result')
346347

348+
347349
args = parser.parse_args()
348350
main(
349351
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,

torchao/_models/llama/model.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,25 @@ def update(self, input_pos, k_val, v_val):
9797

9898
return k_out, v_out
9999

100+
# class QuantizedKVCache(nn.Module):
101+
# def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
102+
# super().__init__()
103+
# cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
104+
# self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.uint8))
105+
# self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.uint8))
106+
# self.register_buffer('k_cache_scale', torch.ones(cache_shape, dtype=torch.bfloat16))
107+
# self.register_buffer('v_cache_scale', torch.ones(cache_shape, dtype=torch.bfloat16))
108+
109+
# def update(self, input_pos, k_val, v_val):
110+
# k_out = self.k_cache
111+
# v_out = self.v_cache
112+
# k_out[:, :, input_pos] = k_val
113+
# v_out[:, :, input_pos] = v_val
114+
115+
# @classmethod
116+
# def from_kv_cache(cls, kv_cache):
117+
118+
100119
class Transformer(nn.Module):
101120
def __init__(self, config: ModelArgs) -> None:
102121
super().__init__()

0 commit comments

Comments
 (0)