From f171a3eddce114af03ab3fdf5ffcca63256e2c3f Mon Sep 17 00:00:00 2001 From: tfaod <8447104+tfaod@users.noreply.github.com> Date: Sat, 23 Aug 2025 19:17:19 +0000 Subject: [PATCH] Add ademamix with optimal hparams --- submissions/self_tuning/ademamix/__init__.py | 0 .../self_tuning/ademamix/submission.py | 378 ++++++++++++++++++ .../self_tuning/ademamix/submission_info.yml | 154 +++++++ 3 files changed, 532 insertions(+) create mode 100644 submissions/self_tuning/ademamix/__init__.py create mode 100644 submissions/self_tuning/ademamix/submission.py create mode 100644 submissions/self_tuning/ademamix/submission_info.yml diff --git a/submissions/self_tuning/ademamix/__init__.py b/submissions/self_tuning/ademamix/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py new file mode 100644 index 000000000..cdcafdc3f --- /dev/null +++ b/submissions/self_tuning/ademamix/submission.py @@ -0,0 +1,378 @@ +from __future__ import annotations + +import math +import collections +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR +from torch.optim.optimizer import Optimizer + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +USE_PYTORCH_DDP = pytorch_setup()[0] + +# optimal parameters across all workloads +HPARAMS = { + "learning_rate": 2e-3, + "one_minus_beta1": 0.2, + "beta2": 0.995, + "beta3": 0.9995, + "weight_decay": 0.1, + "warmup_factor": 0.02, + "alpha": 8, + "beta3_warmup": 500e3, + "alpha_warmup": 500e3, + "grad_clip": 0.5, +} +HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) + + +def linear_warmup_scheduler(step, alpha_end, alpha_start=0, warmup=1): + if step < warmup: + a = step / float(warmup) + return (1.0-a) * alpha_start + a * alpha_end + return alpha_end + + +def linear_hl_warmup_scheduler(step, beta_end, beta_start=0, warmup=1): + + def f(beta, eps=1e-8): + return math.log(0.5)/math.log(beta+eps)-1 + + def f_inv(t): + return math.pow(0.5, 1/(t+1)) + + if step < warmup: + a = step / float(warmup) + return f_inv((1.0-a) * f(beta_start) + a * f(beta_end)) + return beta_end + + +class AdEMAMix(Optimizer): + r"""Implements the AdEMAMix algorithm. 
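+  AdEMAMix maintains a fast and a slow exponential moving average (EMA) of the
+  gradient and combines them as fast_EMA + alpha * slow_EMA, normalized by an
+  Adam-style second-moment estimate, with decoupled weight decay as in AdamW.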
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999, 0.9999)) + corresponding to beta_1, beta_2, beta_3 in AdEMAMix + alpha (float): AdEMAMix alpha coeficient mixing the slow and fast EMAs (default: 2) + beta3_warmup (int, optional): number of warmup steps used to increase beta3 (default: None) + alpha_warmup: (int, optional): number of warmup steps used to increase alpha (default: None) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay as in AdamW (default: 0) + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=2.0, + beta3_warmup=None, alpha_warmup=None, eps=1e-8, + weight_decay=0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + defaults = dict(lr=lr, betas=betas, eps=eps, alpha=alpha, beta3_warmup=beta3_warmup, + alpha_warmup=alpha_warmup, weight_decay=weight_decay) + super(AdEMAMix, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdEMAMix, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + lr = group["lr"] + wd = group["weight_decay"] + eps = group["eps"] + beta1, beta2, beta3_final = group["betas"] + beta3_warmup = group["beta3_warmup"] + alpha_final = group["alpha"] + alpha_warmup = group["alpha_warmup"] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('AdEMAMix does not support sparse gradients.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + if beta1 != 0.0: # save memory in case beta1 is 0.0 + state['exp_avg_fast'] = torch.zeros_like(p, memory_format=torch.preserve_format) + else: + state['exp_avg_fast'] = None + state['exp_avg_slow'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg_fast, exp_avg_slow, exp_avg_sq = state['exp_avg_fast'], state['exp_avg_slow'], state['exp_avg_sq'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Compute the effective alpha and beta3 in case warmup is used + if alpha_warmup is not None: + alpha = linear_warmup_scheduler(state["step"], alpha_end=alpha_final, alpha_start=0, warmup=alpha_warmup) + else: + alpha = alpha_final + + if beta3_warmup is not None: + beta3 = linear_hl_warmup_scheduler(state["step"], beta_end=beta3_final, beta_start=beta1, warmup=beta3_warmup) + else: + beta3 = beta3_final + + # Decay the first and second moment running average coefficient + if beta1 != 0.0: + exp_avg_fast.mul_(beta1).add_(grad, alpha=1 - beta1) + else: + exp_avg_fast = grad + exp_avg_slow.mul_(beta3).add_(grad, alpha=1 - beta3) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + update = (exp_avg_fast.div(bias_correction1) + alpha * exp_avg_slow) / denom + + # decay + update.add_(p, alpha=wd) + + p.add_(-lr * update) + + return loss + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a Lion optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + AdEMAMix( + model_params.parameters(), + lr=HPARAMS.learning_rate, + betas=(1.0 - HPARAMS.one_minus_beta1, + HPARAMS.beta2, + HPARAMS.beta3), + weight_decay=HPARAMS.weight_decay, + alpha=HPARAMS.alpha, + beta3_warmup= HPARAMS.beta3_warmup, + alpha_warmup=HPARAMS.alpha_warmup) + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, HPARAMS, optimizer_state['optimizer']) + optimizer_state['hyperparameters'] = hyperparameters + + return 
optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(HPARAMS, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + + +def get_batch_size(workload_name): + # Return the global batch size. 
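+  # HPARAMS has no batch_size field by default, so the hasattr check below
+  # falls through to the per-workload defaults; adding a batch_size entry to
+  # HPARAMS would override them for every workload.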
+ if hasattr(HPARAMS, "batch_size"): + return HPARAMS.batch_size + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/submissions/self_tuning/ademamix/submission_info.yml b/submissions/self_tuning/ademamix/submission_info.yml new file mode 100644 index 000000000..05282b931 --- /dev/null +++ b/submissions/self_tuning/ademamix/submission_info.yml @@ -0,0 +1,154 @@ +name: algoperf +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2025.2.25=h06a4308_0 + - expat=2.7.1=h6a678d5_0 + - ld_impl_linux-64=2.40=h12ee557_0 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libuuid=1.41.5=h5eee18b_0 + - libxcb=1.17.0=h9b100fa_0 + - ncurses=6.5=h7934f7d_0 + - openssl=3.0.17=h5eee18b_0 + - pthread-stubs=0.3=h0ce48e5_1 + - python=3.11.13=h1a3bd86_0 + - readline=8.2=h5eee18b_0 + - setuptools=78.1.1=py311h06a4308_0 + - sqlite=3.50.2=hb25bd0a_1 + - tk=8.6.14=h993c535_1 + - wheel=0.45.1=py311h06a4308_0 + - xorg-libx11=1.8.12=h9b100fa_1 + - xorg-libxau=1.0.12=h9b100fa_0 + - xorg-libxdmcp=1.1.5=h9b100fa_0 + - xorg-xorgproto=2024.1=h5eee18b_1 + - xz=5.6.4=h5eee18b_1 + - zlib=1.2.13=h5eee18b_1 + - pip: + - absl-py==2.1.0 + - algoperf==0.5.1.dev494+g22b07c82f + - array-record==0.7.2 + - astunparse==1.6.3 + - attrs==25.3.0 + - certifi==2025.8.3 + - charset-normalizer==3.4.2 + - chex==0.1.86 + - click==8.2.1 + - cloudpickle==3.1.1 + - clu==0.0.12 + - contourpy==1.3.3 + - cycler==0.12.1 + - decorator==5.2.1 + - dm-tree==0.1.9 + - docker==7.1.0 + - docstring-parser==0.17.0 + - einops==0.8.1 + - etils==1.13.0 + - filelock==3.18.0 + - flatbuffers==25.2.10 + - flax==0.8.4 + - fonttools==4.59.0 + - fsspec==2025.7.0 + - gast==0.6.0 + - google-pasta==0.2.0 + - googleapis-common-protos==1.70.0 + - gputil==1.4.0 + - grpcio==1.74.0 + - h5py==3.12.0 + - humanize==4.12.3 + - idna==3.10 + - imageio==2.37.0 + - immutabledict==4.2.1 + - importlib-resources==6.5.2 + - jax==0.4.28 + - jaxlib==0.4.28 + - jinja2==3.1.6 + - joblib==1.5.1 + - jraph==0.0.6.dev0 + - keras==3.11.1 + - kiwisolver==1.4.8 + - lazy-loader==0.4 + - libclang==18.1.1 + - markdown==3.8.2 + - markdown-it-py==3.0.0 + - 
markupsafe==3.0.2 + - matplotlib==3.10.5 + - mdurl==0.1.2 + - ml-collections==1.1.0 + - ml-dtypes==0.4.1 + - mpmath==1.3.0 + - msgpack==1.1.1 + - namex==0.1.0 + - nest-asyncio==1.6.0 + - networkx==3.2.1 + - numpy==2.0.2 + - nvidia-cublas-cu12==12.4.5.8 + - nvidia-cuda-cupti-cu12==12.4.127 + - nvidia-cuda-nvrtc-cu12==12.4.127 + - nvidia-cuda-runtime-cu12==12.4.127 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.2.1.3 + - nvidia-curand-cu12==10.3.5.147 + - nvidia-cusolver-cu12==11.6.1.9 + - nvidia-cusparse-cu12==12.3.1.170 + - nvidia-nccl-cu12==2.21.5 + - nvidia-nvjitlink-cu12==12.4.127 + - nvidia-nvtx-cu12==12.4.127 + - opt-einsum==3.4.0 + - optax==0.2.2 + - optree==0.17.0 + - orbax-checkpoint==0.6.4 + - packaging==25.0 + - pandas==2.3.1 + - pillow==11.3.0 + - pip==25.2 + - promise==2.3 + - protobuf==4.25.5 + - psutil==6.1.0 + - pyarrow==21.0.0 + - pydub==0.25.1 + - pygments==2.19.2 + - pyparsing==3.2.3 + - python-dateutil==2.9.0.post0 + - pytz==2025.2 + - pyyaml==6.0.2 + - requests==2.32.4 + - rich==14.1.0 + - ruff==0.12.7 + - scikit-image==0.24.0 + - scikit-learn==1.5.2 + - scipy==1.16.1 + - sentencepiece==0.2.0 + - simple-parsing==0.1.7 + - six==1.17.0 + - sympy==1.13.1 + - tabulate==0.9.0 + - tensorboard==2.18.0 + - tensorboard-data-server==0.7.2 + - tensorflow==2.18.0 + - tensorflow-datasets==4.9.7 + - tensorflow-io-gcs-filesystem==0.37.1 + - tensorflow-metadata==1.17.2 + - tensorflow-probability==0.20.0 + - tensorflow-text==2.18.0 + - tensorstore==0.1.74 + - termcolor==3.1.0 + - threadpoolctl==3.6.0 + - tifffile==2025.6.11 + - toml==0.10.2 + - toolz==1.0.0 + - torch==2.5.1 + - torchvision==0.20.1 + - tqdm==4.67.1 + - triton==3.1.0 + - typing-extensions==4.14.1 + - tzdata==2025.2 + - urllib3==2.5.0 + - werkzeug==3.1.3 + - wrapt==1.17.2 + - zipp==3.23.0 \ No newline at end of file
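
Below is a minimal standalone sketch, not part of the patch itself, showing how the AdEMAMix optimizer and the warmup-plus-cosine schedule built in init_optimizer_state can be exercised outside the AlgoPerf harness. It assumes the repository root is on PYTHONPATH so that submissions.self_tuning.ademamix.submission is importable; the toy model, data, and step_hint below are placeholders, not values from the submission.

import torch

from submissions.self_tuning.ademamix.submission import AdEMAMix, HPARAMS

# Toy model and data; any parameter container would do.
model = torch.nn.Linear(10, 1)
x, y = torch.randn(64, 10), torch.randn(64, 1)

# Optimizer constructed the same way init_optimizer_state does.
optimizer = AdEMAMix(
    model.parameters(),
    lr=HPARAMS.learning_rate,
    betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2, HPARAMS.beta3),
    alpha=HPARAMS.alpha,
    beta3_warmup=HPARAMS.beta3_warmup,
    alpha_warmup=HPARAMS.alpha_warmup,
    weight_decay=HPARAMS.weight_decay)

# Same linear-warmup + cosine-decay shape as pytorch_cosine_warmup, with a
# placeholder step_hint standing in for workload.step_hint.
step_hint = 1_000
warmup_steps = int(HPARAMS.warmup_factor * step_hint)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1e-10, end_factor=1.0,
            total_iters=warmup_steps),
        torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=max(step_hint - warmup_steps, 1)),
    ],
    milestones=[warmup_steps])

for _ in range(10):
  optimizer.zero_grad()
  loss = torch.nn.functional.mse_loss(model(x), y)
  loss.backward()
  # update_params clips the global gradient norm before stepping.
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=HPARAMS.grad_clip)
  optimizer.step()
  scheduler.step()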
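
A second small sketch, under the same import assumption, illustrates the beta3 warmup: linear_hl_warmup_scheduler interpolates the EMA half-life f(beta) = log(0.5)/log(beta) - 1 linearly between beta_start and beta_end rather than interpolating beta itself, so the slow EMA's effective memory (in steps) grows linearly over the warmup. Here beta_start=0.8 mirrors beta1 under one_minus_beta1=0.2, and beta_end=0.9995 mirrors HPARAMS.beta3.

from submissions.self_tuning.ademamix.submission import linear_hl_warmup_scheduler

warmup = 500_000  # matches HPARAMS.beta3_warmup
for step in (0, 1_000, 50_000, 250_000, 500_000):
  beta3 = linear_hl_warmup_scheduler(
      step, beta_end=0.9995, beta_start=0.8, warmup=warmup)
  print(f'step={step:>7d}  beta3={beta3:.6f}')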