Commit 5ffac6f

jamesdev9 authored and abetlen committed

feat: add support for KV cache quantization options (#1307)

* add KV cache quantization options abetlen/llama-cpp-python#1220 abetlen/llama-cpp-python#1305
* Add ggml_type
* Use ggml_type instead of string for quantization
* Add server support

Co-authored-by: Andrei Betlen <[email protected]>

1 parent 0cc5e4d, commit 5ffac6f

File tree: 4 files changed (+94, -41 lines)

llama_cpp/llama.py (+18, -41)
@@ -105,6 +105,9 @@ def __init__(
         draft_model: Optional[LlamaDraftModel] = None,
         # Tokenizer Override
         tokenizer: Optional[BaseLlamaTokenizer] = None,
+        # KV cache quantization
+        type_k: Optional[int] = None,
+        type_v: Optional[int] = None,
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -172,6 +175,8 @@ def __init__(
             draft_model: Optional draft model to use for speculative decoding.
             tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
             verbose: Print verbose output to stderr.
+            type_k: KV cache data type for K (default: f16)
+            type_v: KV cache data type for V (default: f16)

        Raises:
            ValueError: If the model path does not exist.
@@ -298,7 +303,11 @@ def __init__(
         ) # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-
+        # KV cache quantization
+        if type_k is not None:
+            self.context_params.type_k = type_k
+        if type_v is not None:
+            self.context_params.type_v = type_v
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size

@@ -1724,6 +1733,7 @@ def __getstate__(self):
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
+            pooling_type=self.context_params.pooling_type,
             rope_freq_base=self.context_params.rope_freq_base,
             rope_freq_scale=self.context_params.rope_freq_scale,
             yarn_ext_factor=self.context_params.yarn_ext_factor,
@@ -1733,6 +1743,7 @@ def __getstate__(self):
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
+            offload_kqv=self.context_params.offload_kqv,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
@@ -1744,51 +1755,17 @@ def __getstate__(self):
             # Chat Format Params
             chat_format=self.chat_format,
             chat_handler=self.chat_handler,
+            # Speculative Decoding
+            draft_model=self.draft_model,
+            # KV cache quantization
+            type_k=self.context_params.type_k,
+            type_v=self.context_params.type_v,
             # Misc
             verbose=self.verbose,
         )

     def __setstate__(self, state):
-        self.__init__(
-            model_path=state["model_path"],
-            # Model Params
-            n_gpu_layers=state["n_gpu_layers"],
-            split_mode=state["split_mode"],
-            main_gpu=state["main_gpu"],
-            tensor_split=state["tensor_split"],
-            vocab_only=state["vocab_only"],
-            use_mmap=state["use_mmap"],
-            use_mlock=state["use_mlock"],
-            kv_overrides=state["kv_overrides"],
-            # Context Params
-            seed=state["seed"],
-            n_ctx=state["n_ctx"],
-            n_batch=state["n_batch"],
-            n_threads=state["n_threads"],
-            n_threads_batch=state["n_threads_batch"],
-            rope_freq_base=state["rope_freq_base"],
-            rope_freq_scale=state["rope_freq_scale"],
-            rope_scaling_type=state["rope_scaling_type"],
-            yarn_ext_factor=state["yarn_ext_factor"],
-            yarn_attn_factor=state["yarn_attn_factor"],
-            yarn_beta_fast=state["yarn_beta_fast"],
-            yarn_beta_slow=state["yarn_beta_slow"],
-            yarn_orig_ctx=state["yarn_orig_ctx"],
-            logits_all=state["logits_all"],
-            embedding=state["embedding"],
-            # Sampling Params
-            last_n_tokens_size=state["last_n_tokens_size"],
-            # LoRA Params
-            lora_base=state["lora_base"],
-            lora_path=state["lora_path"],
-            # Backend Params
-            numa=state["numa"],
-            # Chat Format Params
-            chat_format=state["chat_format"],
-            chat_handler=state["chat_handler"],
-            # Misc
-            verbose=state["verbose"],
-        )
+        self.__init__(**state)

     def save_state(self) -> LlamaState:
         assert self._ctx.ctx is not None

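For reference, a minimal usage sketch (an illustration, not part of the commit): it assumes a placeholder model path and picks q8_0 for both caches via the GGML_TYPE_* constants added in llama_cpp/llama_cpp.py below; any other supported ggml_type integer is passed the same way.

import llama_cpp

# Illustration only: the model path is a placeholder and q8_0 is just one
# possible cache type.
llm = llama_cpp.Llama(
    model_path="./models/model.gguf",
    type_k=llama_cpp.GGML_TYPE_Q8_0,  # quantize the key cache to q8_0
    type_v=llama_cpp.GGML_TYPE_Q8_0,  # quantize the value cache to q8_0
)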
llama_cpp/llama_cpp.py (+64, -0)
@@ -141,6 +141,70 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:

 byref = ctypes.byref # type: ignore

+# from ggml.h
+# // NOTE: always add types at the end of the enum to keep backward compatibility
+# enum ggml_type {
+#     GGML_TYPE_F32 = 0,
+#     GGML_TYPE_F16 = 1,
+#     GGML_TYPE_Q4_0 = 2,
+#     GGML_TYPE_Q4_1 = 3,
+#     // GGML_TYPE_Q4_2 = 4, support has been removed
+#     // GGML_TYPE_Q4_3 = 5, support has been removed
+#     GGML_TYPE_Q5_0 = 6,
+#     GGML_TYPE_Q5_1 = 7,
+#     GGML_TYPE_Q8_0 = 8,
+#     GGML_TYPE_Q8_1 = 9,
+#     GGML_TYPE_Q2_K = 10,
+#     GGML_TYPE_Q3_K = 11,
+#     GGML_TYPE_Q4_K = 12,
+#     GGML_TYPE_Q5_K = 13,
+#     GGML_TYPE_Q6_K = 14,
+#     GGML_TYPE_Q8_K = 15,
+#     GGML_TYPE_IQ2_XXS = 16,
+#     GGML_TYPE_IQ2_XS = 17,
+#     GGML_TYPE_IQ3_XXS = 18,
+#     GGML_TYPE_IQ1_S = 19,
+#     GGML_TYPE_IQ4_NL = 20,
+#     GGML_TYPE_IQ3_S = 21,
+#     GGML_TYPE_IQ2_S = 22,
+#     GGML_TYPE_IQ4_XS = 23,
+#     GGML_TYPE_I8 = 24,
+#     GGML_TYPE_I16 = 25,
+#     GGML_TYPE_I32 = 26,
+#     GGML_TYPE_I64 = 27,
+#     GGML_TYPE_F64 = 28,
+#     GGML_TYPE_IQ1_M = 29,
+#     GGML_TYPE_COUNT,
+# };
+GGML_TYPE_F32 = 0
+GGML_TYPE_F16 = 1
+GGML_TYPE_Q4_0 = 2
+GGML_TYPE_Q4_1 = 3
+GGML_TYPE_Q5_0 = 6
+GGML_TYPE_Q5_1 = 7
+GGML_TYPE_Q8_0 = 8
+GGML_TYPE_Q8_1 = 9
+GGML_TYPE_Q2_K = 10
+GGML_TYPE_Q3_K = 11
+GGML_TYPE_Q4_K = 12
+GGML_TYPE_Q5_K = 13
+GGML_TYPE_Q6_K = 14
+GGML_TYPE_Q8_K = 15
+GGML_TYPE_IQ2_XXS = 16
+GGML_TYPE_IQ2_XS = 17
+GGML_TYPE_IQ3_XXS = 18
+GGML_TYPE_IQ1_S = 19
+GGML_TYPE_IQ4_NL = 20
+GGML_TYPE_IQ3_S = 21
+GGML_TYPE_IQ2_S = 22
+GGML_TYPE_IQ4_XS = 23
+GGML_TYPE_I8 = 24
+GGML_TYPE_I16 = 25
+GGML_TYPE_I32 = 26
+GGML_TYPE_I64 = 27
+GGML_TYPE_F64 = 28
+GGML_TYPE_IQ1_M = 29
+GGML_TYPE_COUNT = 30

 # from ggml-backend.h
 # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);

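Since type_k and type_v are plain integers, a small helper like the hypothetical sketch below (not part of the commit) can translate human-readable names into the constants above; the name set shown is only a subset.

import llama_cpp

# Hypothetical helper: map a short name such as "q8_0" to the matching
# GGML_TYPE_* integer defined in llama_cpp/llama_cpp.py.
_KV_CACHE_TYPES = {
    "f32": llama_cpp.GGML_TYPE_F32,
    "f16": llama_cpp.GGML_TYPE_F16,
    "q4_0": llama_cpp.GGML_TYPE_Q4_0,
    "q8_0": llama_cpp.GGML_TYPE_Q8_0,
}

def kv_cache_type_from_name(name: str) -> int:
    try:
        return _KV_CACHE_TYPES[name.lower()]
    except KeyError:
        raise ValueError(f"unsupported KV cache type: {name!r}") from None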
llama_cpp/server/model.py (+3, -0)
@@ -175,6 +175,9 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         chat_handler=chat_handler,
         # Speculative Decoding
         draft_model=draft_model,
+        # KV Cache Quantization
+        type_k=settings.type_k,
+        type_v=settings.type_v,
         # Tokenizer
         tokenizer=tokenizer,
         # Misc

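A hedged sketch of driving this path programmatically (not from the commit): it assumes ModelSettings takes the GGUF path through its model field, and the path used here is a placeholder.

import llama_cpp
from llama_cpp.server.settings import ModelSettings
from llama_cpp.server.model import load_llama_from_model_settings

# Build settings the same way the server would, with KV cache quantization enabled.
settings = ModelSettings(
    model="./models/model.gguf",        # placeholder path
    type_k=llama_cpp.GGML_TYPE_Q8_0,    # key cache quantization type
    type_v=llama_cpp.GGML_TYPE_Q8_0,    # value cache quantization type
)
llm = load_llama_from_model_settings(settings)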
llama_cpp/server/settings.py (+9, -0)
@@ -159,6 +159,15 @@ class ModelSettings(BaseSettings):
         default=10,
         description="Number of tokens to predict using the draft model.",
     )
+    # KV Cache Quantization
+    type_k: Optional[int] = Field(
+        default=None,
+        description="Type of the key cache quantization.",
+    )
+    type_v: Optional[int] = Field(
+        default=None,
+        description="Type of the value cache quantization.",
+    )
     # Misc
     verbose: bool = Field(
         default=True, description="Whether to print debug information."

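Because both fields default to None, the small sketch below (an illustration, not from the commit) shows that leaving them unset keeps llama.cpp's own f16 default, since Llama only overrides context_params.type_k/type_v when a value is given; the server picks these fields up like any other ModelSettings field.

from llama_cpp.server.settings import ModelSettings

# Placeholder model path; only the defaults of the new fields matter here.
settings = ModelSettings(model="./models/model.gguf")
assert settings.type_k is None and settings.type_v is None  # f16 default stays in effect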