23
23
24
24
from aqt .jax .v2 import aqt_tensor
25
25
from aqt .jax .v2 import config as aqt_config
26
+ from aqt .jax .v2 .aqt_tensor import QTensor as KVTensor
26
27
from aqt .jax .v2 .flax import aqt_flax
27
28
28
- from MaxText import common_types
29
+ from MaxText .common_types import Array , AxisNames , AxisIdxes , Config , CACHE_BATCH_PREFILL , DType , MODEL_MODE_PREFILL , MODEL_MODE_TRAIN , MODEL_MODE_AUTOREGRESSIVE , CACHE_HEADS_NONE , DECODING_ACTIVE_SEQUENCE_INDICATOR
30
+ from MaxText .common_types import CACHE_BATCH , CACHE_SEQUENCE , CACHE_HEADS , CACHE_KV , CACHE_SCALE_BATCH , CACHE_SCALE_SEQUENCE , CACHE_SCALE_HEADS , CACHE_SCALE_KV
29
31
30
- Array = common_types .Array
31
- AxisNames = common_types .AxisNames
32
- AxisIdxes = common_types .AxisIdxes
33
- Config = common_types .Config
34
- KVTensor = aqt_tensor .QTensor
35
32
36
33
MAX_INT8 = 127.5
37
34
MAX_INT4 = 7.5
38
35
E4M3_MAX = jnp .finfo (jnp .float8_e4m3fn ).max .astype (jnp .float32 )
39
36
40
- CACHE_BATCH_PREFILL = common_types .CACHE_BATCH_PREFILL
41
- CACHE_BATCH = common_types .CACHE_BATCH
42
- CACHE_SEQUENCE = common_types .CACHE_SEQUENCE
43
- CACHE_HEADS = common_types .CACHE_HEADS
44
- CACHE_KV = common_types .CACHE_KV
45
- CACHE_SCALE_BATCH = common_types .CACHE_SCALE_BATCH
46
- CACHE_SCALE_SEQUENCE = common_types .CACHE_SCALE_SEQUENCE
47
- CACHE_SCALE_HEADS = common_types .CACHE_SCALE_HEADS
48
- CACHE_SCALE_KV = common_types .CACHE_SCALE_KV
49
-
50
37
51
38
def reverse_transpose (transposed_array , transpose_axis_order ):
52
39
return jax .numpy .moveaxis (transposed_array , (0 , 1 , 2 , 3 ), transpose_axis_order )
@@ -167,7 +154,7 @@ class KVCache(nn.Module):
167
154
168
155
max_prefill_length : int
169
156
max_target_length : int
170
- dtype : common_types . DType
157
+ dtype : DType
171
158
kv_quant : Optional [KVQuant ] = None
172
159
prefill_cache_logical_axis_names : AxisNames = (CACHE_BATCH_PREFILL , CACHE_SEQUENCE , CACHE_HEADS , CACHE_KV )
173
160
cache_logical_axis_names : AxisNames = (CACHE_BATCH , CACHE_SEQUENCE , CACHE_HEADS , CACHE_KV )
@@ -194,7 +181,7 @@ def _get_prefill_cache_vars(self, batch, key_heads, value_heads, key_head_size,
194
181
cache_length = self .max_prefill_length
195
182
dtype = self ._get_cached_kv_dtype ()
196
183
197
- if model_mode == common_types . MODEL_MODE_PREFILL :
184
+ if model_mode == MODEL_MODE_PREFILL :
198
185
cache_logical_axis_names = self .prefill_cache_logical_axis_names
199
186
else :
200
187
cache_logical_axis_names = self .cache_logical_axis_names
@@ -219,7 +206,7 @@ def _get_prefill_cache_vars(self, batch, key_heads, value_heads, key_head_size,
219
206
cache_shape_value ,
220
207
dtype ,
221
208
)
222
- if model_mode == common_types . MODEL_MODE_PREFILL :
209
+ if model_mode == MODEL_MODE_PREFILL :
223
210
segment_id_axis_names = (CACHE_BATCH_PREFILL , CACHE_SEQUENCE )
224
211
else :
225
212
segment_id_axis_names = (CACHE_BATCH , CACHE_SEQUENCE )
@@ -274,7 +261,7 @@ def _get_ar_cache_vars(self, batch, key_heads, value_heads, key_head_size, value
274
261
)
275
262
cache_length = self .max_target_length - self .max_prefill_length
276
263
277
- if model_mode == common_types . MODEL_MODE_PREFILL :
264
+ if model_mode == MODEL_MODE_PREFILL :
278
265
cache_logical_axis_names = self .prefill_cache_logical_axis_names
279
266
else :
280
267
cache_logical_axis_names = self .cache_logical_axis_names
@@ -311,7 +298,7 @@ def _get_ar_cache_vars(self, batch, key_heads, value_heads, key_head_size, value
311
298
cache_axis_names ,
312
299
)
313
300
314
- if model_mode == common_types . MODEL_MODE_PREFILL :
301
+ if model_mode == MODEL_MODE_PREFILL :
315
302
segment_id_axis_names = (CACHE_BATCH_PREFILL , CACHE_SEQUENCE )
316
303
else :
317
304
segment_id_axis_names = (CACHE_BATCH , CACHE_SEQUENCE )
@@ -401,11 +388,11 @@ def kv_cache_chunked_prefill(
401
388
next_pos = previous_chunk .shape [1 ]
402
389
403
390
cached_prefill_key_vars , cached_prefill_value_vars , cached_prefill_segment_id_var = self ._get_prefill_cache_vars (
404
- batch , key_heads , value_heads , key_head_size , value_head_size , common_types . MODEL_MODE_PREFILL
391
+ batch , key_heads , value_heads , key_head_size , value_head_size , MODEL_MODE_PREFILL
405
392
)
406
393
# TODO: Find a way to not enable the ar cache for prefill mode.
407
394
_ = self ._get_ar_cache_vars (
408
- batch , key_heads , value_heads , key_head_size , value_head_size , common_types . MODEL_MODE_PREFILL
395
+ batch , key_heads , value_heads , key_head_size , value_head_size , MODEL_MODE_PREFILL
409
396
) # initialize it now
410
397
411
398
key_shaped_for_cache = jnp .transpose (key , self .prefill_cache_axis_order )
@@ -488,11 +475,11 @@ def kv_cache_prefill(
488
475
assert key .dtype == value .dtype , "Key and Value Dtypes should match."
489
476
490
477
cached_prefill_key_vars , cached_prefill_value_vars , cached_prefill_segment_id_var = self ._get_prefill_cache_vars (
491
- batch , key_heads , value_heads , key_head_size , value_head_size , common_types . MODEL_MODE_PREFILL
478
+ batch , key_heads , value_heads , key_head_size , value_head_size , MODEL_MODE_PREFILL
492
479
)
493
480
# TODO: Find a way to not enable the ar cache for prefill mode.
494
481
_ = self ._get_ar_cache_vars (
495
- batch , key_heads , value_heads , key_head_size , value_head_size , common_types . MODEL_MODE_PREFILL
482
+ batch , key_heads , value_heads , key_head_size , value_head_size , MODEL_MODE_PREFILL
496
483
) # initialize it now
497
484
498
485
key_shaped_for_cache = jnp .transpose (key , self .prefill_cache_axis_order )
@@ -652,9 +639,7 @@ def kv_cache_autoregressive(
652
639
raise ValueError (f"Sequence length should be 1 during autoregression, got { sequence = } " )
653
640
654
641
cached_ar_key_vars , cached_ar_value_vars , cached_ar_segment_id_var , cache_ar_index_var , cache_ar_lengths_var = (
655
- self ._get_ar_cache_vars (
656
- batch , key_heads , value_heads , key_head_size , value_head_size , common_types .MODEL_MODE_AUTOREGRESSIVE
657
- )
642
+ self ._get_ar_cache_vars (batch , key_heads , value_heads , key_head_size , value_head_size , MODEL_MODE_AUTOREGRESSIVE )
658
643
)
659
644
660
645
self .update_ar_key_value (
@@ -666,7 +651,7 @@ def kv_cache_autoregressive(
666
651
cache_ar_lengths_var .value ,
667
652
use_ragged_attention ,
668
653
)
669
- active_indicator = jnp .zeros ((batch , 1 ), dtype = jnp .int32 ) + common_types . DECODING_ACTIVE_SEQUENCE_INDICATOR
654
+ active_indicator = jnp .zeros ((batch , 1 ), dtype = jnp .int32 ) + DECODING_ACTIVE_SEQUENCE_INDICATOR
670
655
cached_ar_segment_id_var .value = jax .lax .dynamic_update_index_in_dim (
671
656
cached_ar_segment_id_var .value , active_indicator , jnp .squeeze (cache_ar_index_var .value ), 1
672
657
)
@@ -675,7 +660,7 @@ def kv_cache_autoregressive(
675
660
676
661
# The below retrieves the existing prefill cache variables, not creating new ones
677
662
cached_prefill_key_vars , cached_prefill_value_vars , cached_prefill_segment_id_var = self ._get_prefill_cache_vars (
678
- batch , key_heads , value_heads , key_head_size , value_head_size , common_types . MODEL_MODE_AUTOREGRESSIVE
663
+ batch , key_heads , value_heads , key_head_size , value_head_size , MODEL_MODE_AUTOREGRESSIVE
679
664
)
680
665
681
666
cached_prefill = (
@@ -719,12 +704,12 @@ def __call__(
719
704
two tuples of (k, v, decoder_segments) -- either can be Nones
720
705
721
706
"""
722
- if model_mode == common_types . MODEL_MODE_PREFILL :
707
+ if model_mode == MODEL_MODE_PREFILL :
723
708
if self .use_chunked_prefill :
724
709
return self .kv_cache_chunked_prefill (key , value , decoder_segment_ids , previous_chunk ), None
725
710
else :
726
711
return self .kv_cache_prefill (key , value , decoder_segment_ids ), None
727
- elif model_mode == common_types . MODEL_MODE_AUTOREGRESSIVE :
712
+ elif model_mode == MODEL_MODE_AUTOREGRESSIVE :
728
713
return self .kv_cache_autoregressive (key , value , use_ragged_attention )
729
714
else :
730
715
raise ValueError (f"Model Mode isn't supported! { model_mode = } " )
@@ -736,13 +721,13 @@ class MlaKVCache(KVCache):
736
721
prefill_cache_logical_axis_names : AxisNames = (
737
722
CACHE_BATCH_PREFILL ,
738
723
CACHE_SEQUENCE ,
739
- common_types . CACHE_HEADS_NONE ,
724
+ CACHE_HEADS_NONE ,
740
725
CACHE_KV ,
741
726
)
742
727
cache_logical_axis_names : AxisNames = (
743
728
CACHE_BATCH ,
744
729
CACHE_SEQUENCE ,
745
- common_types . CACHE_HEADS_NONE ,
730
+ CACHE_HEADS_NONE ,
746
731
CACHE_KV ,
747
732
)
748
733
@@ -767,7 +752,7 @@ def __call__(
767
752
Optional [Tuple [Array , Array , Array ]],
768
753
Optional [Tuple [Array , Array , Array , Array ]],
769
754
]:
770
- assert model_mode != common_types . MODEL_MODE_TRAIN , "incorrectly updating kvcache in train mode."
755
+ assert model_mode != MODEL_MODE_TRAIN , "incorrectly updating kvcache in train mode."
771
756
assert self .kv_quant is None , "kvcache quantization not supported with mla."
772
757
key_latent = self .key_latent_add_head_dim (key_latent )
773
758
prefill_cache , ar_cache = super ().__call__ (key_latent , key_rope , decoder_segment_ids , model_mode )
0 commit comments