Convert DenseGeneral to NNX #1710

Closed
wants to merge 11 commits into from
2 changes: 1 addition & 1 deletion .github/workflows/CPUTests.yml
@@ -36,7 +36,7 @@ jobs:
pytype --jobs auto --disable 'import-error,late-directive,wrong-arg-types,module-attr,unsupported-operands' MaxText/ || true
- name: Analysing the code with pylint in Maxtext/
run: |
pylint --verbose --msg-template='[{abspath}] {msg_id}:{line:3d},{column}: {obj}: {msg}' --disable E0102,E0606,E0611,E1102,E1111,E1120,E1121,E1123,E1135,E1136,R0401,R1701,R1703,R1710,R1711,R1735,R0917,R1714,R1716,R1719,R1721,R1728,R1728,W0102,W0107,W0201,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0613,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 MaxText/ && \
pylint --verbose --msg-template='[{abspath}] {msg_id}:{line:3d},{column}: {obj}: {msg}' --disable R0401,R1701,R1703,R1710,R1711,R1735,R0917,R1714,R1716,R1719,R1721,R1728,R1728,W0102,W0107,W0201,W0212,W0221,W0223,W0237,W0404,W0611,W0612,W0613,W0621,W0622,W0631,W0707,W0718,W1201,W1203,W1309,W1514,W4901 MaxText/ && \
echo 'Maxtext PyLint check successful' || { echo \
'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
- name: Analysing the code with pylint in pedagogical_examples/
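Note on the change above: the dropped `E*` codes are pylint's error-class checks (for example `E0102` function-redefined and `E1120` no-value-for-parameter), so after this change they are enforced for `MaxText/`. A minimal, hypothetical illustration of the kind of code the re-enabled `E0102` check would now fail CI on (the snippet is not from this repository):

```python
# Hypothetical snippet, not from MaxText: pylint's E0102 (function-redefined)
# flags the second definition; with E0102 no longer disabled, this fails CI.
def compute_metric(x):
  return x * 2


def compute_metric(x):  # E0102: function already defined above
  return x + 2
```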
10 changes: 5 additions & 5 deletions .github/workflows/RunTests.yml
@@ -60,7 +60,7 @@ jobs:
device_name: a100-40gb-4
build_mode: stable_stack
base_image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:latest

cpu_unit_tests:
needs: tpu_image
uses: ./.github/workflows/run_tests_internal.yml
@@ -80,7 +80,7 @@
with:
device_type: tpu
device_name: v4-8
# cloud_runner: linux-x86-ct4p-240-4tpu
cloud_runner: linux-x86-ct4p-240-4tpu
pytest_marker: 'not cpu_only and not gpu_only and not integration_test'
xla_python_client_mem_fraction: 0.75
tf_force_gpu_allow_growth: false
@@ -92,7 +92,7 @@
with:
device_type: tpu
device_name: v4-8
# cloud_runner: linux-x86-ct4p-240-4tpu
cloud_runner: linux-x86-ct4p-240-4tpu
pytest_marker: 'not cpu_only and not gpu_only and integration_test'
xla_python_client_mem_fraction: 0.75
tf_force_gpu_allow_growth: false
@@ -104,7 +104,7 @@
with:
device_type: gpu
device_name: a100-40gb-4
# cloud_runner: linux-x86-a2-48-a100-4gpu
cloud_runner: linux-x86-a2-48-a100-4gpu
pytest_marker: 'not cpu_only and not tpu_only and not integration_test'
xla_python_client_mem_fraction: 0.65
tf_force_gpu_allow_growth: true
@@ -116,7 +116,7 @@
with:
device_type: gpu
device_name: a100-40gb-4
# cloud_runner: linux-x86-a2-48-a100-4gpu
cloud_runner: linux-x86-a2-48-a100-4gpu
pytest_marker: 'not cpu_only and not tpu_only and integration_test'
xla_python_client_mem_fraction: 0.65
tf_force_gpu_allow_growth: true
17 changes: 17 additions & 0 deletions MaxText/benchmark_chunked_prefill.py
@@ -354,6 +354,23 @@ def benchmark_prefix_cache(


def prepare_setting(argv: Sequence[str]):
"""
Constructs the necessary components for benchmarking chunked prefill with prefix caching.
Args:
argv: The command-line arguments.

Returns:
engine (maxengine.MaxEngine): The MaxEngine instance.
params (Any): The model parameters.
tokens (jax.Array): The input token sequence.
chunked_tokens_list (list[chunked_prefill.ChunkedTokens]): A list of ChunkedTokens objects representing the
input sequence split into chunks.
prefix_caching_hbm_byte (int): The size of the HBM layer in the prefix cache.
prefix_caching_dram_byte (int): The size of the DRAM layer in the prefix cache.
chunk_size (int): The chunk size used for prefilling.
max_prefill_length (int): The maximum length of the prefill sequence.
"""
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
config = pyconfig.initialize(argv)
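For readers of the new docstring, a minimal usage sketch of the documented return tuple; the unpacking below is illustrative and mirrors the docstring order rather than quoting the rest of the benchmark script:

```python
# Illustrative only: unpack prepare_setting's documented return values
# in the order listed in the docstring above.
(
    engine,
    params,
    tokens,
    chunked_tokens_list,
    prefix_caching_hbm_byte,
    prefix_caching_dram_byte,
    chunk_size,
    max_prefill_length,
) = prepare_setting(argv)
```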
36 changes: 16 additions & 20 deletions MaxText/checkpointing.py
@@ -17,18 +17,26 @@
"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""

from typing import Any, Optional, Union

from absl import flags

from etils import epath
from flax.training import train_state

import grain.python as grain
import jax
from MaxText import max_logging
from MaxText.multihost_dataloading import MultiHostDataLoadIterator

import numpy as np

import jax

from flax.training import train_state

import orbax.checkpoint as ocp
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

from MaxText import max_logging
from MaxText.multihost_dataloading import MultiHostDataLoadIterator

# pylint: disable=too-many-positional-arguments

CheckpointManager = ocp.CheckpointManager
@@ -143,11 +151,7 @@ def create_orbax_emergency_replicator_checkpoint_manager(

def replicator_error_handler(config: Any):
"""Replicator error handler to handle errors in replicator service."""
if (
config.enable_emergency_checkpoint
and config.use_replicator_service
and config.local_checkpoint_directory
):
if config.enable_emergency_checkpoint and config.use_replicator_service and config.local_checkpoint_directory:
local_dir = config.local_checkpoint_directory
replicator_errors_file = f"{local_dir}/replicator.errors"
replicator_failed_file = f"{local_dir}/replicator.failed"
@@ -156,9 +160,7 @@ def replicator_error_handler(config: Any):
# if the replicator.failed file exists, then we have a fatal error
is_fatal = process_replicator_error_file(replicator_failed_file)
if is_fatal:
raise ValueError(
"Replicator fatal error found in replicator.failed file."
)
raise ValueError("Replicator fatal error found in replicator.failed file.")


def process_replicator_error_file(error_file: str) -> bool:
@@ -178,21 +180,15 @@ def read_replicator_error_file(error_file: str):
error_data = epath.Path(error_file).read_text()
max_logging.log(f"Contents of replicator error file:\n{error_data}")
except (OSError, ValueError) as e:
max_logging.log(
"replicator_error_handler: Failed to read contents of failed"
f" file: {e}"
)
max_logging.log("replicator_error_handler: Failed to read contents of failed" f" file: {e}")


def cleanup_replicator_error_file(error_file: str):
"""Clean up replicator errors file."""
try:
epath.Path(error_file).unlink()
except (OSError, ValueError) as e:
max_logging.log(
"replicator_error_handler: Failed to remove replicator errors file:"
f" {e}"
)
max_logging.log("replicator_error_handler: Failed to remove replicator errors file:" f" {e}")


def print_save_message(step, async_checkpointing):
6 changes: 4 additions & 2 deletions MaxText/common_types.py
@@ -16,10 +16,12 @@
import enum
from typing import Any, Sequence

from flax.linen import partitioning
import numpy as np

import jax
import jax.numpy as jnp
import numpy as np

from flax.linen import partitioning

Config = Any

15 changes: 15 additions & 0 deletions MaxText/configs/base.yml
@@ -215,6 +215,7 @@ qkv_proj: 'remat'
out_proj: 'remat'

optimizer_memory_host_offload: False
parameter_memory_host_offload: False
scan_layers: True # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations.
param_scan_axis: 1

@@ -709,3 +710,17 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
image_path: "" # Local image path used for decoding


### llama4 multi modal configs
hidden_size_for_vit: 1408
intermediate_size_for_vit: 5632
num_attention_heads_for_vit: 16
num_channels_for_vit: 3
patch_size_for_vit: 14
num_hidden_layers_for_vit: 34
projector_input_dim_for_vit: 4096
projector_output_dim_for_vit: 4096
rope_theta_for_vit: 10000
vision_output_dim_for_vit: 4096
pixel_shuffle_ratio_for_vit: 0.5
projector_dropout_for_vit: 0.0
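These new llama4 vision-encoder keys follow the usual MaxText pattern of becoming attributes on the parsed config. A hedged sketch of reading them, assuming `pyconfig.initialize` is given a config path as in other MaxText entry points; the `vit_settings` dict and its keys are purely illustrative:

```python
# Sketch under assumptions: access the new *_for_vit keys on a parsed config.
from MaxText import pyconfig

config = pyconfig.initialize(["", "MaxText/configs/base.yml"])
vit_settings = {
    "hidden_size": config.hidden_size_for_vit,        # 1408 by default
    "num_heads": config.num_attention_heads_for_vit,  # 16
    "num_layers": config.num_hidden_layers_for_vit,   # 34
    "patch_size": config.patch_size_for_vit,          # 14
    "image_size": config.image_size_for_vit,          # 896 default, 336 in the llama4 model configs
}
```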
1 change: 1 addition & 0 deletions MaxText/configs/models/llama4-17b-128e.yml
@@ -41,3 +41,4 @@ temperature_tuning: True
# Chunk attention is used on all RoPE layers
# otherwise, on NoPE layers, use global attention
chunk_attn_window_size: 8192
image_size_for_vit: 336
1 change: 1 addition & 0 deletions MaxText/configs/models/llama4-17b-16e.yml
@@ -41,3 +41,4 @@ temperature_tuning: True
# Chunk attention is used on all RoPE layers
# otherwise, on NoPE layers, use global attention
chunk_attn_window_size: 8192
image_size_for_vit: 336
17 changes: 9 additions & 8 deletions MaxText/convert_gemma2_chkpt.py
@@ -16,25 +16,26 @@
Convert orbax Gemma checkpoint to MaxText compatible checkpoint.
"""

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_platform_name", "cpu")
from typing import Any
import argparse
import copy
from flax.training import train_state

from typing import Any
import sys

import numpy as np

import jax
import jax.numpy as jnp

from flax.training import train_state

import orbax

from MaxText import max_logging
from MaxText import checkpointing
from MaxText.train import save_checkpoint

jax.config.update("jax_platform_name", "cpu")

Params = dict[str, Any]


18 changes: 10 additions & 8 deletions MaxText/convert_gemma3_chkpt.py
@@ -11,25 +11,27 @@
limitations under the License.
"""

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_platform_name", "cpu")
from typing import Any
import argparse
import copy
from flax.training import train_state

from typing import Any
import sys

import numpy as np

import jax
import jax.numpy as jnp

from flax.training import train_state

import orbax

from MaxText import checkpointing
from MaxText import max_logging
from MaxText.train import save_checkpoint

jax.config.update("jax_platform_name", "cpu")


Params = dict[str, Any]


19 changes: 10 additions & 9 deletions MaxText/convert_gemma_chkpt.py
@@ -16,25 +16,26 @@
Convert orbax Gemma checkpoint to MaxText compatible checkpoint.
"""

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_platform_name", "cpu")
from typing import Any
import argparse
import copy
from flax.training import train_state

from typing import Any
import sys

import numpy as np

import jax
import jax.numpy as jnp

from flax.training import train_state

import orbax

from MaxText import checkpointing
from MaxText import max_logging
from MaxText.train import save_checkpoint

jax.config.update("jax_platform_name", "cpu")

Params = dict[str, Any]


@@ -57,7 +58,7 @@ def main(raw_args=None) -> None:
parser.add_argument("--model_size", type=str, required=True)
args = parser.parse_args(raw_args)
if args.model_size not in ("2b", "7b", "9b"):
raise NotImplementedError
raise NotImplementedError(args.model_size)

print("Loading checkpoint")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
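The change above attaches the offending value to the `NotImplementedError`, so an unsupported `--model_size` is self-describing. An alternative not taken in this diff would be to reject the value at parse time via argparse's `choices`; a sketch with the other arguments elided:

```python
# Alternative sketch, not part of this PR: validate --model_size when parsing.
import argparse

parser = argparse.ArgumentParser()
# ... checkpoint-path arguments elided ...
parser.add_argument("--model_size", type=str, required=True, choices=("2b", "7b", "9b"))
args = parser.parse_args()
# argparse exits with its own clear error for unsupported sizes,
# making the explicit NotImplementedError guard redundant.
```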
10 changes: 8 additions & 2 deletions MaxText/elastic_train.py
@@ -46,16 +46,22 @@
import queue

from absl import app

from cloud_tpu_diagnostics import diagnostic
from cloud_tpu_diagnostics.configuration import debug_configuration
from cloud_tpu_diagnostics.configuration import diagnostic_configuration
from cloud_tpu_diagnostics.configuration import stack_trace_configuration
from flax.linen import partitioning as nn_partitioning

import jax

from flax.linen import partitioning as nn_partitioning

from ml_goodput_measurement import monitoring

import pathwaysutils
from pathwaysutils.elastic import manager
from pathwaysutils.debug import timing

import tensorflow as tf

from MaxText import checkpointing
@@ -267,7 +273,7 @@ def train_loop(config, elastic_manager, state=None):
block=True,
)

input_data_shardings = maxtext_utils.get_input_data_shardings(config, mesh)
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)
# Using while loop instead of a for loop because with elasticity
# the step is restored back to the latest snapshot when a slice is lost
while step < config.steps:
10 changes: 7 additions & 3 deletions MaxText/experimental/rl/grpo_input_pipeline.py
@@ -14,15 +14,19 @@
limitations under the License.
"""

import functools
from collections.abc import Iterable

import numpy as np

import jax
from jax.sharding import Mesh

import numpy as np
import functools
import datasets

import transformers

import grain.python as grain
from collections.abc import Iterable

from MaxText.input_pipeline import input_pipeline_interface
from MaxText.input_pipeline import _input_pipeline_utils