@@ -133,19 +133,20 @@ def make_tensors_list() -> List[str]:

 def find_n_mult(n_ff: int, n_embd: int) -> int:
     # hardcoded magic range
-    for n_mult in range(256, 1, -1):
+    for n_mult in range(8192, 1, -1):
         calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
         if calc_ff == n_ff:
             return n_mult
     raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")

 @dataclass
 class Params:
-    n_vocab: int
-    n_embd:  int
-    n_mult:  int
-    n_head:  int
-    n_layer: int
+    n_vocab:   int
+    n_embd:    int
+    n_mult:    int
+    n_head:    int
+    n_layer:   int
+    n_kv_head: Optional[int]  # This parameter is only used for Llama 2

     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
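Why the wider range: for Llama 2 70B (assuming the published config values n_embd = 8192 and n_ff = 28672), the feed-forward rounding multiple works out to 7168, far above the old cap of 256, so the old loop could never match and raised. A quick worked check under those assumed values:

    # 8 * 8192 // 3 = 21845; rounding 21845 up to a multiple of 7168 gives
    # 4 * 7168 = 28672, which matches n_ff, so 7168 is returned. No candidate
    # <= 256 can ever reach 28672, since that requires 28672 - n_mult < 21845.
    assert find_n_mult(28672, 8192) == 7168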
@@ -167,11 +168,12 @@ def guessed(model: 'LazyModel') -> 'Params':
         n_head = n_embd // 128  # guessed

         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = 256,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = 256,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = None,
         )

     @staticmethod
@@ -183,15 +185,17 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         n_head = config["num_attention_heads"];
         n_layer = config["num_hidden_layers"];
         n_ff = config["intermediate_size"];
+        n_kv_head = config.get("num_key_value_heads")

         n_mult = find_n_mult(n_ff, n_embd);

         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = n_mult,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = n_kv_head,
         )

     # LLaMA v2 70B params.json
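The .get() lookup matters: only grouped-query checkpoints publish num_key_value_heads, so older configs fall back to None and the permute path keeps plain multi-head behavior. A minimal sketch with hypothetical config excerpts (not full HF files):

    llama2_70b_cfg = {"num_attention_heads": 64, "num_key_value_heads": 8}
    llama1_7b_cfg  = {"num_attention_heads": 32}

    assert llama2_70b_cfg.get("num_key_value_heads") == 8
    assert llama1_7b_cfg.get("num_key_value_heads") is None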
@@ -200,21 +204,22 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
     def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))

-        n_vocab = config["vocab_size"];
-        n_embd = config["dim"];
-        n_head = config["n_heads"];
-        n_layer = config["n_layers"];
-        n_mult = config["multiple_of"];
+        n_vocab = config["vocab_size"];
+        n_embd  = config["dim"];
+        n_head  = config["n_heads"];
+        n_layer = config["n_layers"];
+        n_mult  = config["multiple_of"];

         if n_vocab == -1:
             n_vocab = model["tok_embeddings.weight"].shape[0]

         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
+            n_vocab   = n_vocab,
+            n_embd    = n_embd,
+            n_mult    = n_mult,
+            n_head    = n_head,
+            n_layer   = n_layer,
+            n_kv_head = None,
         )

     @staticmethod
@@ -317,10 +322,12 @@ def __repr__(self) -> str:
 Vocab = Union[SentencePieceVocab, GGMLVocab]


-def permute(weights: NDArray, n_head: int) -> NDArray:
+def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
+    if n_kv_head is not None and n_head != n_kv_head:
+        n_head //= n_kv_head
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                   .swapaxes(1, 2)
-                   .reshape(weights.shape))
+            .swapaxes(1, 2)
+            .reshape(weights.shape))


 def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray:
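What the extra argument does: under GQA the K projection has only n_kv_head * head_dim rows instead of n_head * head_dim, so the rotary regrouping must use n_head // n_kv_head groups. A standalone toy-sized sketch (numpy only, with a local copy of the function above):

    import numpy as np

    def permute(weights, n_head, n_kv_head=None):
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

    # 2 KV heads with head_dim 4 -> k_proj has 2 * 4 = 8 rows; 8 query heads share them.
    k_proj = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
    out = permute(k_proj, n_head=8, n_kv_head=2)   # regroups rows as 8 // 2 = 4
    assert out.shape == k_proj.shape               # rows reordered, shape preserved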
@@ -368,7 +375,7 @@ class Tensor(metaclass=ABCMeta):
     @abstractmethod
     def astype(self, data_type: DataType) -> 'Tensor': ...
     @abstractmethod
-    def permute(self, n_head: int) -> 'Tensor': ...
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
     @abstractmethod
     def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
     @abstractmethod
@@ -406,8 +413,8 @@ def part(self, n_part: int) -> 'UnquantizedTensor':
         r = self.ndarray.shape[0] // 3
         return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])

-    def permute(self, n_head: int) -> 'UnquantizedTensor':
-        return UnquantizedTensor(permute(self.ndarray, n_head))
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
+        return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))


 def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
@@ -455,26 +462,27 @@ def astype(self, data_type: DataType) -> Tensor:
     def to_ggml(self) -> 'GGMLQuantizedTensor':
         return self

-    def permute(self, n_head: int) -> 'GGMLQuantizedTensor':
-        return GGMLQuantizedTensor(permute(self.ndarray, n_head), self.shape, self.data_type)
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'GGMLQuantizedTensor':
+        return GGMLQuantizedTensor(permute(self.ndarray, n_head, n_kv_head), self.shape, self.data_type)


 GGMLCompatibleTensor = Union[UnquantizedTensor, GGMLQuantizedTensor]


 class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int) -> None:
+    def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
         self.base = base
         self.n_head = n_head
+        self.n_kv_head = n_kv_head
         self.data_type = self.base.data_type

     def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head)
+        return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)

     def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head)
+        return self.base.to_ggml().permute(self.n_head, self.n_kv_head)

-    def permute(self, n_head: int) -> Tensor:
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
         raise Exception("shouldn't permute twice")
@@ -566,8 +574,8 @@ def regroup(self, new_groupsize: int = 32) -> 'GPTQForLLaMaQuantizedTensor':
         ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False)
         return ret

-    def permute(self, n_head: int) -> Tensor:
-        return DeferredPermutedTensor(self, n_head)
+    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
+        return DeferredPermutedTensor(self, n_head, n_kv_head)

     def to_ggml(self) -> GGMLQuantizedTensor:
         # The output format looks like this:
@@ -698,10 +706,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
     return ModelPlus(model, paths, format, vocab)


-def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
+def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
     def load() -> Tensor:
-        return lazy_tensor.load().permute(n_head)
-    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
+        return lazy_tensor.load().permute(n_head, n_kv_head)
+    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)

 def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
     def load() -> Tensor:
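The closure defers the actual permute until load time, and recording n_kv_head in the description string keeps the lazy graph debuggable. A hypothetical stand-in for LazyTensor (not the real class) to illustrate the pattern, reusing the permute sketch above:

    import numpy as np
    from dataclasses import dataclass
    from typing import Callable, Optional

    @dataclass
    class MiniLazy:  # hypothetical stand-in, not llama.cpp's LazyTensor
        load: Callable[[], np.ndarray]
        description: str

    def mini_permute_lazy(t: MiniLazy, n_head: int, n_kv_head: Optional[int] = None) -> MiniLazy:
        def load() -> np.ndarray:
            return permute(t.load(), n_head, n_kv_head)  # permute from the earlier sketch
        return MiniLazy(load, f'permute({n_head}, {n_kv_head}) ' + t.description)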
@@ -726,7 +734,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
-            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
             out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)