
Commit 5312ae5

Merge branch 'main' into bvandermoon-nnx-dev
2 parents fca144b + b373b9b commit 5312ae5


53 files changed (+378, -700 lines)

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 # Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml
-* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @rni418 @gagika @shralex @yangyuwei @SurbhiJainUSC @hengtaoguo @A9isha @wang2yn84 @wyzhang @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis
+* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @rni418 @gagika @shralex @yangyuwei @SurbhiJainUSC @hengtaoguo @A9isha @wang2yn84 @wyzhang @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @aireenmei

.github/workflows/AddLabel.yml

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ jobs:
             jrplatin: "",
             patemotter: "",
             lumosis: "",
+            aireenmei: "",
           }
           const reviews = await github.rest.pulls.listReviews({
             owner,

.github/workflows/CPUTests.yml

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ jobs:
         pytype --jobs auto --disable 'import-error,late-directive,wrong-arg-types,module-attr,unsupported-operands' MaxText/ || true
     - name: Analysing the code with pylint in Maxtext/
       run: |
-        pylint --verbose --msg-template='[{abspath}] {msg_id}:{line:3d},{column}: {obj}: {msg}' --disable R0401,R1701,R1703,R1710,R1711,R1735,R0917,R1714,R1716,R1719,R1721,R1728,R1728,W0102,W0107,W0201,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0613,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 MaxText/ && \
+        pylint --verbose --msg-template='[{abspath}] {msg_id}:{line:3d},{column}: {obj}: {msg}' --disable R0401,R0917,W0102,W0107,W0201,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0613,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 MaxText/ && \
         echo 'Maxtext PyLint check successful' || { echo \
         'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
     - name: Analysing the code with pylint in pedagogical_examples/
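For context on what the shorter --disable list means in practice: the R17xx refactor checks dropped from it (for example R1714 and R1735) are now enforced on MaxText/. A minimal, hypothetical Python snippet of the kind of code those two checks flag, assuming the usual pylint meanings (R1714 consider-using-in, R1735 use-dict-literal):

def mode_is_inference(model_mode):
  # Would now be flagged by R1714 (consider-using-in) if written as:
  #   return model_mode == "prefill" or model_mode == "autoregressive"
  return model_mode in ("prefill", "autoregressive")


def empty_metrics():
  # Would now be flagged by R1735 (use-dict-literal) if written as dict().
  return {}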

MaxText/common_types.py

Lines changed: 0 additions & 6 deletions
@@ -18,21 +18,15 @@
 
 import numpy as np
 
-import jax
 import jax.numpy as jnp
 
-from flax.linen import partitioning
-
 Config = Any
 
 Array = jnp.ndarray
 PRNGKey = jnp.ndarray
 DType = jnp.dtype
 Shape = Sequence[int]
 
-Mesh = jax.sharding.Mesh
-ScanIn = partitioning.ScanIn
-
 AxisNames = tuple[str, ...]
 AxisIdxes = tuple[int, ...]
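With the jax and flax.linen imports removed, common_types no longer re-exports Mesh or ScanIn, so callers are expected to import those names from their source libraries (paged_attention.py below does exactly this for Mesh). A minimal sketch of the adjusted imports; the ScanIn line is an assumption about how a former user of the alias would compensate, not something shown in this commit:

# Old pattern, no longer available after this change:
#   from MaxText import common_types
#   mesh = common_types.Mesh(...)

# Direct imports from the owning libraries instead.
from jax.sharding import Mesh                # as paged_attention.py now does
from flax.linen.partitioning import ScanIn   # assumed replacement for the removed alias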

MaxText/experimental/rl/grpo_trainer.py

Lines changed: 7 additions & 9 deletions
@@ -14,10 +14,6 @@
 limitations under the License.
 """
 
-from collections.abc import Callable
-
-from MaxText.common_types import Array
-
 # pylint: disable=g-bad-todo, abstract-method, consider-using-with, attribute-error
 """
 This script implements Group Relative Policy Optimization (GRPO) training
@@ -32,6 +28,7 @@
 import functools
 import queue
 from typing import Sequence
+from collections.abc import Callable
 
 from absl import app
 
@@ -56,17 +53,17 @@
 import transformers
 
 from MaxText import checkpointing
+from MaxText import max_logging
 from MaxText import max_utils
+from MaxText import maxengine
 from MaxText import maxtext_utils
-from MaxText import max_logging
 from MaxText import profiler
 from MaxText import pyconfig
-from MaxText import maxengine
-from MaxText.metric_logger import MetricLogger
-from MaxText.vertex_tensorboard import VertexTensorboardManager
+from MaxText.common_types import Array
 from MaxText.experimental.rl import grpo_input_pipeline
-from MaxText.layers import models
 from MaxText.gcp_workload_monitor import GCPWorkloadMonitor
+from MaxText.layers import models
+from MaxText.metric_logger import MetricLogger
 from MaxText.train import (
     validate_train_config,
     get_first_step,
@@ -78,6 +75,7 @@
     check_example_batch,
     setup_mesh_and_model,
 )
+from MaxText.vertex_tensorboard import VertexTensorboardManager
 
 # pylint: disable=too-many-positional-arguments
 
MaxText/inference/kvcache.py

Lines changed: 20 additions & 35 deletions
@@ -23,30 +23,17 @@
 
 from aqt.jax.v2 import aqt_tensor
 from aqt.jax.v2 import config as aqt_config
+from aqt.jax.v2.aqt_tensor import QTensor as KVTensor
 from aqt.jax.v2.flax import aqt_flax
 
-from MaxText import common_types
+from MaxText.common_types import Array, AxisNames, AxisIdxes, Config, CACHE_BATCH_PREFILL, DType, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, CACHE_HEADS_NONE, DECODING_ACTIVE_SEQUENCE_INDICATOR
+from MaxText.common_types import CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV, CACHE_SCALE_BATCH, CACHE_SCALE_SEQUENCE, CACHE_SCALE_HEADS, CACHE_SCALE_KV
 
-Array = common_types.Array
-AxisNames = common_types.AxisNames
-AxisIdxes = common_types.AxisIdxes
-Config = common_types.Config
-KVTensor = aqt_tensor.QTensor
 
 MAX_INT8 = 127.5
 MAX_INT4 = 7.5
 E4M3_MAX = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
 
-CACHE_BATCH_PREFILL = common_types.CACHE_BATCH_PREFILL
-CACHE_BATCH = common_types.CACHE_BATCH
-CACHE_SEQUENCE = common_types.CACHE_SEQUENCE
-CACHE_HEADS = common_types.CACHE_HEADS
-CACHE_KV = common_types.CACHE_KV
-CACHE_SCALE_BATCH = common_types.CACHE_SCALE_BATCH
-CACHE_SCALE_SEQUENCE = common_types.CACHE_SCALE_SEQUENCE
-CACHE_SCALE_HEADS = common_types.CACHE_SCALE_HEADS
-CACHE_SCALE_KV = common_types.CACHE_SCALE_KV
-
 
 def reverse_transpose(transposed_array, transpose_axis_order):
   return jax.numpy.moveaxis(transposed_array, (0, 1, 2, 3), transpose_axis_order)
@@ -167,7 +154,7 @@ class KVCache(nn.Module):
 
   max_prefill_length: int
   max_target_length: int
-  dtype: common_types.DType
+  dtype: DType
   kv_quant: Optional[KVQuant] = None
   prefill_cache_logical_axis_names: AxisNames = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
  cache_logical_axis_names: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
@@ -194,7 +181,7 @@ def _get_prefill_cache_vars(self, batch, key_heads, value_heads, key_head_size,
     cache_length = self.max_prefill_length
     dtype = self._get_cached_kv_dtype()
 
-    if model_mode == common_types.MODEL_MODE_PREFILL:
+    if model_mode == MODEL_MODE_PREFILL:
       cache_logical_axis_names = self.prefill_cache_logical_axis_names
     else:
       cache_logical_axis_names = self.cache_logical_axis_names
@@ -219,7 +206,7 @@ def _get_prefill_cache_vars(self, batch, key_heads, value_heads, key_head_size,
         cache_shape_value,
         dtype,
     )
-    if model_mode == common_types.MODEL_MODE_PREFILL:
+    if model_mode == MODEL_MODE_PREFILL:
       segment_id_axis_names = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE)
     else:
       segment_id_axis_names = (CACHE_BATCH, CACHE_SEQUENCE)
@@ -274,7 +261,7 @@ def _get_ar_cache_vars(self, batch, key_heads, value_heads, key_head_size, value
     )
     cache_length = self.max_target_length - self.max_prefill_length
 
-    if model_mode == common_types.MODEL_MODE_PREFILL:
+    if model_mode == MODEL_MODE_PREFILL:
       cache_logical_axis_names = self.prefill_cache_logical_axis_names
     else:
       cache_logical_axis_names = self.cache_logical_axis_names
@@ -311,7 +298,7 @@ def _get_ar_cache_vars(self, batch, key_heads, value_heads, key_head_size, value
         cache_axis_names,
     )
 
-    if model_mode == common_types.MODEL_MODE_PREFILL:
+    if model_mode == MODEL_MODE_PREFILL:
       segment_id_axis_names = (CACHE_BATCH_PREFILL, CACHE_SEQUENCE)
     else:
       segment_id_axis_names = (CACHE_BATCH, CACHE_SEQUENCE)
@@ -401,11 +388,11 @@ def kv_cache_chunked_prefill(
     next_pos = previous_chunk.shape[1]
 
     cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars(
-        batch, key_heads, value_heads, key_head_size, value_head_size, common_types.MODEL_MODE_PREFILL
+        batch, key_heads, value_heads, key_head_size, value_head_size, MODEL_MODE_PREFILL
     )
     # TODO: Find a way to not enable the ar cache for prefill mode.
     _ = self._get_ar_cache_vars(
-        batch, key_heads, value_heads, key_head_size, value_head_size, common_types.MODEL_MODE_PREFILL
+        batch, key_heads, value_heads, key_head_size, value_head_size, MODEL_MODE_PREFILL
     )  # initialize it now
 
     key_shaped_for_cache = jnp.transpose(key, self.prefill_cache_axis_order)
@@ -488,11 +475,11 @@ def kv_cache_prefill(
     assert key.dtype == value.dtype, "Key and Value Dtypes should match."
 
     cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars(
-        batch, key_heads, value_heads, key_head_size, value_head_size, common_types.MODEL_MODE_PREFILL
+        batch, key_heads, value_heads, key_head_size, value_head_size, MODEL_MODE_PREFILL
     )
     # TODO: Find a way to not enable the ar cache for prefill mode.
     _ = self._get_ar_cache_vars(
-        batch, key_heads, value_heads, key_head_size, value_head_size, common_types.MODEL_MODE_PREFILL
+        batch, key_heads, value_heads, key_head_size, value_head_size, MODEL_MODE_PREFILL
     )  # initialize it now
 
     key_shaped_for_cache = jnp.transpose(key, self.prefill_cache_axis_order)
@@ -652,9 +639,7 @@ def kv_cache_autoregressive(
       raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}")
 
     cached_ar_key_vars, cached_ar_value_vars, cached_ar_segment_id_var, cache_ar_index_var, cache_ar_lengths_var = (
-        self._get_ar_cache_vars(
-            batch, key_heads, value_heads, key_head_size, value_head_size, common_types.MODEL_MODE_AUTOREGRESSIVE
-        )
+        self._get_ar_cache_vars(batch, key_heads, value_heads, key_head_size, value_head_size, MODEL_MODE_AUTOREGRESSIVE)
     )
 
     self.update_ar_key_value(
@@ -666,7 +651,7 @@ def kv_cache_autoregressive(
         cache_ar_lengths_var.value,
         use_ragged_attention,
     )
-    active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
+    active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + DECODING_ACTIVE_SEQUENCE_INDICATOR
     cached_ar_segment_id_var.value = jax.lax.dynamic_update_index_in_dim(
         cached_ar_segment_id_var.value, active_indicator, jnp.squeeze(cache_ar_index_var.value), 1
     )
@@ -675,7 +660,7 @@ def kv_cache_autoregressive(
 
     # The below retrieves the existing prefill cache variables, not creating new ones
     cached_prefill_key_vars, cached_prefill_value_vars, cached_prefill_segment_id_var = self._get_prefill_cache_vars(
-        batch, key_heads, value_heads, key_head_size, value_head_size, common_types.MODEL_MODE_AUTOREGRESSIVE
+        batch, key_heads, value_heads, key_head_size, value_head_size, MODEL_MODE_AUTOREGRESSIVE
    )
 
     cached_prefill = (
@@ -719,12 +704,12 @@ def __call__(
       two tuples of (k, v, decoder_segments) -- either can be Nones
 
     """
-    if model_mode == common_types.MODEL_MODE_PREFILL:
+    if model_mode == MODEL_MODE_PREFILL:
       if self.use_chunked_prefill:
         return self.kv_cache_chunked_prefill(key, value, decoder_segment_ids, previous_chunk), None
       else:
         return self.kv_cache_prefill(key, value, decoder_segment_ids), None
-    elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
+    elif model_mode == MODEL_MODE_AUTOREGRESSIVE:
       return self.kv_cache_autoregressive(key, value, use_ragged_attention)
     else:
       raise ValueError(f"Model Mode isn't supported! {model_mode=}")
@@ -736,13 +721,13 @@ class MlaKVCache(KVCache):
   prefill_cache_logical_axis_names: AxisNames = (
       CACHE_BATCH_PREFILL,
       CACHE_SEQUENCE,
-      common_types.CACHE_HEADS_NONE,
+      CACHE_HEADS_NONE,
       CACHE_KV,
   )
   cache_logical_axis_names: AxisNames = (
       CACHE_BATCH,
       CACHE_SEQUENCE,
-      common_types.CACHE_HEADS_NONE,
+      CACHE_HEADS_NONE,
       CACHE_KV,
   )
 
@@ -767,7 +752,7 @@ def __call__(
       Optional[Tuple[Array, Array, Array]],
       Optional[Tuple[Array, Array, Array, Array]],
   ]:
-    assert model_mode != common_types.MODEL_MODE_TRAIN, "incorrectly updating kvcache in train mode."
+    assert model_mode != MODEL_MODE_TRAIN, "incorrectly updating kvcache in train mode."
     assert self.kv_quant is None, "kvcache quantization not supported with mla."
     key_latent = self.key_latent_add_head_dim(key_latent)
     prefill_cache, ar_cache = super().__call__(key_latent, key_rope, decoder_segment_ids, model_mode)
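The kvcache.py changes follow one pattern throughout: module-level aliases built from the common_types module object are dropped in favor of importing the names directly, and KVTensor becomes an import-time alias of QTensor. A condensed, illustrative before/after sketch (the helper function is hypothetical, just to show the bare-name call sites):

# Before: indirection through the module object.
#   from MaxText import common_types
#   Array = common_types.Array
#   KVTensor = aqt_tensor.QTensor
#   if model_mode == common_types.MODEL_MODE_PREFILL: ...

# After: direct imports; call sites use the bare names.
from aqt.jax.v2.aqt_tensor import QTensor as KVTensor
from MaxText.common_types import Array, DType, MODEL_MODE_PREFILL


def is_prefill(model_mode: str) -> bool:
  # Hypothetical helper illustrating the bare-name comparison style used above.
  return model_mode == MODEL_MODE_PREFILL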

MaxText/inference/page_manager.py

Lines changed: 1 addition & 2 deletions
@@ -32,9 +32,8 @@
 
 from jaxtyping import Array, Integer, Bool
 
-from MaxText import common_types
+from MaxText.common_types import Config
 
-Config = common_types.Config
 
 # Aliases using <Dims><Type><Rank>d convention
 # We use string names for dimensions as they are symbolic within the type hints.

MaxText/inference/paged_attention.py

Lines changed: 7 additions & 21 deletions
@@ -21,33 +21,17 @@
 from typing import Optional
 
 import jax.numpy as jnp
-from jax.experimental import shard_map
+from jax.experimental.shard_map import shard_map
 from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention_kernel
 from jax.sharding import PartitionSpec as P
+from jax.sharding import Mesh
 
 from flax import linen as nn
 
-from MaxText import common_types
 from MaxText.inference import page_manager
 from MaxText.inference import paged_attention_kernel_v2
+from MaxText.common_types import Array, DType, AxisNames, BATCH, LENGTH, HEAD, D_KV, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 
-# pytype: disable=attribute-error
-
-Mesh = common_types.Mesh
-
-Array = common_types.Array
-Config = common_types.Config
-DType = common_types.DType
-Mesh = common_types.Mesh
-PRNGKey = common_types.PRNGKey
-
-AxisNames = common_types.AxisNames
-BATCH = common_types.BATCH
-LENGTH = common_types.LENGTH
-HEAD = common_types.HEAD
-D_KV = common_types.D_KV
-
-shard_map = shard_map.shard_map
 _use_kernel_v2 = False
 
 
@@ -255,16 +239,18 @@ def __call__(
     key_pages_var, value_pages_var = self.init_or_get_kv_pages(model_mode)
 
     # update kv pages and call page attention kernel
-    if model_mode == common_types.MODEL_MODE_PREFILL:
+    if model_mode == MODEL_MODE_PREFILL:
       self.update_prefill_step_pages(key_pages_var, value_pages_var, key, value, slot, page_state)
       if _use_kernel_v2:
         return self.paged_attention_v2_prefill(query, key_pages_var, value_pages_var, page_state), None, None
       return self.paged_dot_product_attention_with_max_and_sum(query, key, value)
-    elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
+    elif model_mode == MODEL_MODE_AUTOREGRESSIVE and page_state is not None:
       self.update_decode_step_pages(key_pages_var, value_pages_var, key, value, page_state)
       if _use_kernel_v2:
         return self.paged_attention_v2_decode(query, key_pages_var, value_pages_var, page_state), None, None
       return self.paged_attention_v1_decode(query, key_pages_var, value_pages_var, page_state), None, None
+    else:
+      raise NotImplementedError(model_mode)
 
   def update_prefill_step_pages(
       self,
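Beyond the import cleanup, the __call__ dispatch now takes the autoregressive path only when a page_state is actually provided, and any other combination raises instead of silently falling through. A stripped-down sketch of the new branching; the handle_* names are placeholders, not the real method names:

def dispatch(self, model_mode, page_state, *args):
  # Condensed view of the updated control flow in __call__ (not the full method).
  if model_mode == MODEL_MODE_PREFILL:
    return self.handle_prefill(*args)             # placeholder for the prefill branch
  elif model_mode == MODEL_MODE_AUTOREGRESSIVE and page_state is not None:
    return self.handle_decode(page_state, *args)  # placeholder for the decode branch
  else:
    # Previously an unmatched mode fell off the end and returned None implicitly.
    raise NotImplementedError(model_mode)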

MaxText/kernels/megablox/gmm.py

Lines changed: 5 additions & 7 deletions
@@ -17,24 +17,22 @@
 # pylint: disable=too-many-positional-arguments, unnecessary-lambda-assignment
 
 from collections.abc import Callable
+from functools import partial
+from typing import Any, Optional, Literal
 import dataclasses
 import functools
-from typing import Any, Optional, Literal
 
-import jax
-import jax.numpy as jnp
 from jax import lax
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
+import jax
+import jax.numpy as jnp
 
 from aqt.jax.v2 import pallas as aqt_pl
-from aqt.jax.v2 import aqt_tensor
+from aqt.jax.v2.aqt_tensor import QTensor
 
 from MaxText.kernels.megablox import common
 
-QTensor = aqt_tensor.QTensor
-partial = functools.partial
-
 
 def _validate_args(
     *,

MaxText/kernels/ragged_attention.py

Lines changed: 1 addition & 7 deletions
@@ -25,14 +25,8 @@
 from jax.experimental import pallas as pl
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
-from jax.experimental import shard_map
 
-from MaxText import common_types
-
-
-BATCH = common_types.BATCH
-DEFAULT_MASK_VALUE = common_types.DEFAULT_MASK_VALUE
-shard_map = shard_map.shard_map
+from MaxText.common_types import DEFAULT_MASK_VALUE
 
 
 def get_mha_cost_estimate(shape_dtype):
