@@ -105,6 +105,9 @@ def __init__(
         draft_model: Optional[LlamaDraftModel] = None,
         # Tokenizer Override
         tokenizer: Optional[BaseLlamaTokenizer] = None,
+        # KV cache quantization
+        type_k: Optional[int] = None,
+        type_v: Optional[int] = None,
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -172,6 +175,8 @@ def __init__(
             draft_model: Optional draft model to use for speculative decoding.
             tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
             verbose: Print verbose output to stderr.
+            type_k: KV cache data type for K (default: f16)
+            type_v: KV cache data type for V (default: f16)
 
         Raises:
             ValueError: If the model path does not exist.
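For context: the new parameters are passed straight through to `llama_context_params.type_k` / `type_v`, so they take the integer `GGML_TYPE_*` enum values exported by the `llama_cpp` module. A minimal usage sketch (the model path is a placeholder, and which quantized cache types actually work depends on the underlying llama.cpp build):

```python
import llama_cpp
from llama_cpp import Llama

# Quantize both halves of the KV cache to q8_0 instead of the default f16,
# roughly halving KV memory at some quality cost.
llm = Llama(
    model_path="./models/7b.gguf",  # placeholder path
    n_ctx=4096,
    type_k=llama_cpp.GGML_TYPE_Q8_0,
    type_v=llama_cpp.GGML_TYPE_Q8_0,
)
```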
@@ -298,7 +303,11 @@ def __init__(
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-
+        # KV cache quantization
+        if type_k is not None:
+            self.context_params.type_k = type_k
+        if type_v is not None:
+            self.context_params.type_v = type_v
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size
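Note the `is not None` guards above: when a caller passes neither parameter, `context_params` keeps whatever `llama_context_default_params()` supplied (f16 for both caches) rather than being overwritten with a sentinel value. That is also why the docstring can state f16 as the default without the Python layer hard-coding it.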
@@ -1724,6 +1733,7 @@ def __getstate__(self):
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
+            pooling_type=self.context_params.pooling_type,
             rope_freq_base=self.context_params.rope_freq_base,
             rope_freq_scale=self.context_params.rope_freq_scale,
             yarn_ext_factor=self.context_params.yarn_ext_factor,
@@ -1733,6 +1743,7 @@ def __getstate__(self):
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
+            offload_kqv=self.context_params.offload_kqv,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
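Besides the KV-cache keys added in the next hunk, these two `__getstate__` hunks also capture `pooling_type` and `offload_kqv`, which were previously dropped on pickling and silently reverted to their defaults. A quick sanity check, reusing the hypothetical `llm` from the earlier sketch:

```python
state = llm.__getstate__()
# Both settings now survive serialization.
print(state["pooling_type"], state["offload_kqv"])
```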
@@ -1744,51 +1755,17 @@ def __getstate__(self):
             # Chat Format Params
             chat_format=self.chat_format,
             chat_handler=self.chat_handler,
+            # Speculative Decoding
+            draft_model=self.draft_model,
+            # KV cache quantization
+            type_k=self.context_params.type_k,
+            type_v=self.context_params.type_v,
             # Misc
             verbose=self.verbose,
         )
 
     def __setstate__(self, state):
-        self.__init__(
-            model_path=state["model_path"],
-            # Model Params
-            n_gpu_layers=state["n_gpu_layers"],
-            split_mode=state["split_mode"],
-            main_gpu=state["main_gpu"],
-            tensor_split=state["tensor_split"],
-            vocab_only=state["vocab_only"],
-            use_mmap=state["use_mmap"],
-            use_mlock=state["use_mlock"],
-            kv_overrides=state["kv_overrides"],
-            # Context Params
-            seed=state["seed"],
-            n_ctx=state["n_ctx"],
-            n_batch=state["n_batch"],
-            n_threads=state["n_threads"],
-            n_threads_batch=state["n_threads_batch"],
-            rope_freq_base=state["rope_freq_base"],
-            rope_freq_scale=state["rope_freq_scale"],
-            rope_scaling_type=state["rope_scaling_type"],
-            yarn_ext_factor=state["yarn_ext_factor"],
-            yarn_attn_factor=state["yarn_attn_factor"],
-            yarn_beta_fast=state["yarn_beta_fast"],
-            yarn_beta_slow=state["yarn_beta_slow"],
-            yarn_orig_ctx=state["yarn_orig_ctx"],
-            logits_all=state["logits_all"],
-            embedding=state["embedding"],
-            # Sampling Params
-            last_n_tokens_size=state["last_n_tokens_size"],
-            # LoRA Params
-            lora_base=state["lora_base"],
-            lora_path=state["lora_path"],
-            # Backend Params
-            numa=state["numa"],
-            # Chat Format Params
-            chat_format=state["chat_format"],
-            chat_handler=state["chat_handler"],
-            # Misc
-            verbose=state["verbose"],
-        )
+        self.__init__(**state)
 
     def save_state(self) -> LlamaState:
         assert self._ctx.ctx is not None
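Collapsing `__setstate__` to `self.__init__(**state)` makes the dict returned by `__getstate__` the single source of truth: a new constructor parameter such as `type_k` only needs to be added there once, instead of being mirrored by hand in a second keyword list. A round-trip sketch (assumes the GGUF file is still readable at the saved path and that `chat_handler` is `None` or otherwise picklable):

```python
import pickle

blob = pickle.dumps(llm)   # __getstate__ captures the __init__ kwargs
llm2 = pickle.loads(blob)  # __setstate__ re-runs __init__(**state), reloading the model
assert llm2.context_params.type_k == llm.context_params.type_k
```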