Skip to content

[Draft] Add Support for Headless Mode Recipe #1708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ dmypy.json

# IntelliJ IDEA settings folder
.idea

5 changes: 3 additions & 2 deletions MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
cloud_logger = None # pytype: disable=attribute-error



def create_orbax_checkpoint_manager(
checkpoint_dir: str,
enable_checkpointing: bool,
Expand Down Expand Up @@ -373,10 +374,10 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-
max_logging.log("Setting up checkpoint logger...")
if config.enable_checkpoint_cloud_logger:
logger_name = f"goodput_{config.run_name}"
options = cloud_logger.CloudLoggerOptions(
options = ocp.logging.CloudLoggerOptions(
job_name=config.run_name, logger_name=logger_name
) # pytype: disable=attribute-error
orbax_cloud_logger = cloud_logger.CloudLogger(options=options) # pytype: disable=attribute-error
orbax_cloud_logger = ocp.logging.CloudLogger(options=options) # pytype: disable=attribute-error
max_logging.log("Successfully set up checkpoint cloud logger.")
return orbax_cloud_logger

Expand Down
3 changes: 3 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_goodput_recording": True,
"enable_single_controller": True,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
Expand All @@ -43,6 +44,7 @@
"async_checkpointing": True,
"checkpoint_period": 100,
"enable_checkpoint_cloud_logger": True,
"enable_goodput_recording": True,
}

# The set of tuning params required for short-running pathways jobs.
Expand All @@ -51,6 +53,7 @@
"async_checkpointing": True,
"checkpoint_period": 20,
"enable_checkpoint_cloud_logger": True,
"enable_goodput_recording": True,
}


Expand Down
231 changes: 142 additions & 89 deletions benchmarks/maxtext_xpk_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@
import threading
import time
from typing import Optional, List
import sys

import omegaconf

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import benchmarks.maxtext_trillium_model_configs as model_configs
from benchmarks.command_utils import run_command_with_updates
import benchmarks.xla_flags_library as xla_flags
Expand Down Expand Up @@ -78,6 +82,7 @@ class PathwaysConfig:
server_flags: str = ''
proxy_flags: str = ''
worker_flags: str = ''
headless: bool = False


# TODO(@vbarr): Split the parameters specific to an XPK workload out from those of a general workload
Expand Down Expand Up @@ -426,130 +431,174 @@ def build_user_command(
return command


def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
"""Get the pathways proxy flags for the workload and removes any extras."""
# Add in the xla flags alongside the proxy flags from the pathways config.
pw_config = wl_config.pathways_config

# Get proxy and xla flag string from model config
proxy_flags_string = pw_config.proxy_flags
xla_flags_string = wl_config.model.xla_flags
def _apply_model_specific_proxy_modifications(
current_proxy_flags_list: list[str], model_config: model_configs.MaxTextModel
) -> list[str]:
"""
Applies XLA flags and other model-specific modifications to the proxy flags list.
This function assumes model_config is not None.
"""
# Make a copy to avoid modifying the original list passed in certain scenarios
proxy_flags = list(current_proxy_flags_list)

# Split both proxy_flags_string and xla_flags_string into lists of flags
proxy_flags_list = proxy_flags_string.strip().split()
# Add in the xla flags from the model config
xla_flags_string = model_config.xla_flags
xla_flags_list = xla_flags_string.strip().split()

# Combine the two lists of flags into a single list
proxy_flags = proxy_flags_list + xla_flags_list
proxy_flags.extend(xla_flags_list) # Combine with existing proxy flags

# Remove the flags that are specified to be removed.
if (
wl_config.model.pathways_xla_flag_options
and xla_flags.REMOVE in wl_config.model.pathways_xla_flag_options
model_config.pathways_xla_flag_options
and xla_flags.REMOVE in model_config.pathways_xla_flag_options
):
flags_to_remove = wl_config.model.pathways_xla_flag_options[
xla_flags.REMOVE
]
flags_to_remove_config = model_config.pathways_xla_flag_options[xla_flags.REMOVE]
# Ensure flags_to_remove_config is an iterable of strings (e.g., list or set)
# If it's a single string of space-separated flags, split it.
if isinstance(flags_to_remove_config, str):
flags_to_remove_set = set(flags_to_remove_config.strip().split())
elif isinstance(flags_to_remove_config, list):
flags_to_remove_set = set(flags_to_remove_config)
else: # Assuming it's already a set or other iterable
flags_to_remove_set = set(flags_to_remove_config)

updated_proxy_flags = []
for flag in proxy_flags:
if flag not in flags_to_remove:
if flag not in flags_to_remove_set:
updated_proxy_flags.append(flag)
proxy_flags = updated_proxy_flags

# Add the flags that are specified to be added.
# Add the flags that are specified to be added for proxy.
if (
model_config.pathways_xla_flag_options
and xla_flags.ADD_PROXY in model_config.pathways_xla_flag_options
):
flags_to_add_str = model_config.pathways_xla_flag_options[xla_flags.ADD_PROXY]
# Original code did .append(), suggesting flags_to_add_str might be a single conceptual flag,
# but it could also be a string of multiple space-separated flags.
# .extend() with .split() is safer if it might be multiple flags.
if isinstance(flags_to_add_str, str):
proxy_flags.extend(flags_to_add_str.strip().split())
elif isinstance(flags_to_add_str, list):
proxy_flags.extend(flags_to_add_str)

return [flag for flag in proxy_flags if flag]


def _apply_model_specific_add_flags(
current_flags_string: str, model_config: model_configs.MaxTextModel, add_flag_key: str
) -> str:
"""
Adds flags specified in model_config.pathways_xla_flag_options for the given add_flag_key.
This function assumes model_config is not None.
"""
flags_to_add_str = ""
if (
wl_config.model.pathways_xla_flag_options
and xla_flags.ADD_PROXY in wl_config.model.pathways_xla_flag_options
model_config.pathways_xla_flag_options
and add_flag_key in model_config.pathways_xla_flag_options
):
flags_to_add = wl_config.model.pathways_xla_flag_options[
xla_flags.ADD_PROXY
]
proxy_flags.append(flags_to_add)
flags_to_add_str = model_config.pathways_xla_flag_options[add_flag_key]

# Join the list of flags back into a single string, space-separated
return ' '.join(proxy_flags)
if flags_to_add_str: # Only add if there's something to add
# Ensure space separation if current_flags_string is not empty
if current_flags_string:
return f"{current_flags_string} {flags_to_add_str}".strip()
return flags_to_add_str.strip()
return current_flags_string # Return original if nothing to add


def _get_pathways_worker_flags(wl_config: WorkloadConfig):
"""Get the pathways worker flags for the workload and removes any extras."""
# Add in the xla flags alongside the worker flags from the pathways config.
def _get_pathways_proxy_flags(wl_config: WorkloadConfig) -> str:
"""Get the pathways proxy flags for the workload."""
pw_config = wl_config.pathways_config

# Get worker and xla flag string from model config
worker_flags = pw_config.worker_flags
# Get base proxy flags string from pathways config
proxy_flags_string = pw_config.proxy_flags
proxy_flags_list = [flag for flag in proxy_flags_string.strip().split() if flag]

# Add the flags that are specified to be added.
if (
wl_config.model.pathways_xla_flag_options
and xla_flags.ADD_WORKER in wl_config.model.pathways_xla_flag_options
):
flags_to_add = wl_config.model.pathways_xla_flag_options[
xla_flags.ADD_WORKER
]
worker_flags += flags_to_add
# Check for headless mode and if model config exists
is_pathways_headless_enabled = pw_config and pw_config.headless
if not is_pathways_headless_enabled and wl_config.model:
proxy_flags_list = _apply_model_specific_proxy_modifications(
proxy_flags_list, wl_config.model
)

# Join the list of flags back into a single string, space-separated
return worker_flags
return ' '.join(proxy_flags_list)


def _get_pathways_server_flags(wl_config: WorkloadConfig):
"""Get the pathways server flags for the workload and removes any extras."""
# Add in the xla flags alongside the server flags from the pathways config.
def _get_pathways_worker_flags(wl_config: WorkloadConfig) -> str:
"""Get the pathways worker flags for the workload."""
pw_config = wl_config.pathways_config
worker_flags_string = pw_config.worker_flags.strip()

# Get server and xla flag string from model config
server_flags = pw_config.server_flags
is_pathways_headless_enabled = pw_config and pw_config.headless
if not is_pathways_headless_enabled and wl_config.model:
worker_flags_string = _apply_model_specific_add_flags(
worker_flags_string, wl_config.model, xla_flags.ADD_WORKER
)

# Add the flags that are specified to be added.
if (
wl_config.model.pathways_xla_flag_options
and xla_flags.ADD_SERVER in wl_config.model.pathways_xla_flag_options
):
flags_to_add = wl_config.model.pathways_xla_flag_options[
xla_flags.ADD_SERVER
]
server_flags += flags_to_add
return worker_flags_string

# Join the list of flags back into a single string, space-separated
return server_flags
def _get_pathways_server_flags(wl_config: WorkloadConfig) -> str:
"""Get the pathways server flags for the workload."""
pw_config = wl_config.pathways_config
server_flags_string = pw_config.server_flags.strip()

is_pathways_headless_enabled = pw_config and pw_config.headless
if not is_pathways_headless_enabled and wl_config.model:
server_flags_string = _apply_model_specific_add_flags(
server_flags_string, wl_config.model, xla_flags.ADD_SERVER
)

return server_flags_string

def _get_pathways_specific_flags(wl_config: WorkloadConfig):

def _get_pathways_specific_flags(wl_config: WorkloadConfig) -> str:
"""Gets all pathways specific flags, including proxy, worker, and server."""
pw_config = wl_config.pathways_config
if pw_config is None:
return ''

colocated_python_sidecar_image_flag = (
f' --colocated-python-sidecar-image={pw_config.colocated_python_sidecar_image}'
f'--colocated-python-sidecar-image={pw_config.colocated_python_sidecar_image}'
if pw_config.colocated_python_sidecar_image is not None
else ''
)
server_image_flag = (
f' --server-image={pw_config.server_image}'
f'--server-image={pw_config.server_image}'
if pw_config.server_image is not None
else ''
)
proxy_server_image_flag = (
f' --proxy-server-image={pw_config.proxy_server_image}'
f'--proxy-server-image={pw_config.proxy_server_image}'
if pw_config.proxy_server_image is not None
else ''
)

proxy_flags = _get_pathways_proxy_flags(wl_config)
worker_flags = _get_pathways_worker_flags(wl_config)
server_flags = _get_pathways_server_flags(wl_config)

pathways_specific_flags = (
f' {server_image_flag} '
f' {proxy_server_image_flag} '
f' {colocated_python_sidecar_image_flag} '
f' --termination-grace-period-seconds=300 '
f' --pathways-gcs-location={wl_config.base_output_directory} '
f' --custom-pathways-server-args="{server_flags}" '
f' --custom-pathways-proxy-server-args="{proxy_flags}" '
f' --custom-pathways-worker-args="{worker_flags}" '
)
return pathways_specific_flags
# These calls now internally handle the headless mode logic
proxy_flags_str = _get_pathways_proxy_flags(wl_config)
worker_flags_str = _get_pathways_worker_flags(wl_config)
server_flags_str = _get_pathways_server_flags(wl_config)

# Construct the flags string carefully to avoid extra spaces
flags_parts = [
server_image_flag,
proxy_server_image_flag,
colocated_python_sidecar_image_flag,
'--termination-grace-period-seconds=300', # Static flag
f'--pathways-gcs-location={wl_config.base_output_directory}'
]
if server_flags_str:
flags_parts.append(f'--custom-pathways-server-args="{server_flags_str}"')
if proxy_flags_str:
flags_parts.append(f'--custom-pathways-proxy-server-args="{proxy_flags_str}"')
if worker_flags_str:
flags_parts.append(f'--custom-pathways-worker-args="{worker_flags_str}"')

pathways_specific_flags_str = ' '.join(flag for flag in flags_parts if flag)

if pw_config.headless:
pathways_specific_flags_str += ' --headless'

return pathways_specific_flags_str.strip()


def generate_xpk_workload_cmd(
Expand All @@ -560,29 +609,31 @@ def generate_xpk_workload_cmd(
"""Generates a command to run a maxtext model on XPK."""

is_pathways_enabled = wl_config.pathways_config is not None
is_pathways_headless_enabled = wl_config.pathways_config and wl_config.pathways_config.headless

time.localtime()
length_of_random_str = 3
temp_post_fix = ''.join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(length_of_random_str)
)

truncate_model_name = 12
truncate_prefix = 5
common_post_fix = f"-{wl_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
truncate_model_name = 10
truncate_prefix = 3
post_fix = f"-{wl_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
common_prefix = os.environ['USER']
pw_prefix = "pw-"

if workload_name is None: # Generate name if not provided
if is_pathways_enabled:
post_fix = f"-{wl_config.num_slices}-{temp_post_fix}"
name = (
f"{pw_prefix}{wl_config.model.model_name.replace('_', '-')[:truncate_model_name - len(pw_prefix)]}"
)
else:
name = (
f"{wl_config.model.model_name.replace('_', '-')[:truncate_model_name]}"
)
name = f"{common_prefix[:truncate_prefix]}-{name}{common_post_fix}"
name = f"{common_prefix[:truncate_prefix]}-{name}{post_fix}"
else:
name = workload_name # Use provided name

Expand All @@ -591,10 +642,12 @@ def generate_xpk_workload_cmd(
wl_config.run_name,
'metrics')

user_command = build_user_command(
name=name,
wl_config=wl_config
)
user_command = ''
if not is_pathways_headless_enabled:
user_command = build_user_command(
name=name,
wl_config=wl_config
)

additional_flags = ''
if not is_pathways_enabled and wl_config.libtpu_type == LibTpuType.CUSTOM:
Expand All @@ -615,10 +668,10 @@ def generate_xpk_workload_cmd(
f'--docker-image={pw_config.runner_image}'
)
else:
docker_image_flag = f'--base-docker-image="{wl_config.base_docker_image}"'
docker_image_flag = f'--docker-image="{wl_config.base_docker_image}"'

upload_metrics_to_bq_cmd = ""
if wl_config.generate_metrics_and_upload_to_big_query:
if wl_config.generate_metrics_and_upload_to_big_query and not is_pathways_headless_enabled:
# TODO (optionally) make it so that this upload step is done on local device instead of within the workload.
args = _build_args_from_config(wl_config)
args_str = ""
Expand Down
Loading
Loading