diff --git a/examples/torchtune/README.md b/examples/torchtune/README.md
new file mode 100644
index 0000000000..10f634dd7c
--- /dev/null
+++ b/examples/torchtune/README.md
@@ -0,0 +1,27 @@
+# torchtune Examples
+Examples for fine-tuning language models using [torchtune](https://github.com/pytorch/torchtune).
+
+## Setup
+1. Follow the [torchao Installation](../../README.md#installation) steps.
+
+2. Install `torchtune`:
+```
+pip install torchtune
+```
+
+## Run
+1. Download a model (see more details [here](https://github.com/pytorch/torchtune#downloading-a-model)):
+```
+tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
+```
+
+2. Finetune:
+- To finetune on a single device:
+```
+tune run --nproc_per_node 1 full_finetune_single_device.py --config ./configs/full_finetune.yaml
+```
+
+- To finetune on multiple GPUs:
+```
+tune run --nproc_per_node 8 full_finetune_distributed.py --config ./configs/full_finetune.yaml
+```
\ No newline at end of file
diff --git a/examples/torchtune/configs/full_finetune.yaml b/examples/torchtune/configs/full_finetune.yaml
new file mode 100644
index 0000000000..2bc896e77d
--- /dev/null
+++ b/examples/torchtune/configs/full_finetune.yaml
@@ -0,0 +1,104 @@
+# Config for full finetuning with full_finetune_distributed.py or
+# full_finetune_single_device.py using a Llama 3.2 1B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 4 devices, run the following command from this directory:
+#   tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed.py --config ./configs/full_finetune.yaml
+#
+# You can add specific overrides through the command line. For example,
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed.py --config ./configs/full_finetune.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works best when the model is being fine-tuned on 2+ GPUs.
+# Single-device full finetuning requires more memory optimizations; it's
+# best to use full_finetune_single_device.py for those cases.
+
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
+  max_seq_len: null
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.alpaca_dataset
+  packed: False  # True increases speed
+seed: null
+shuffle: True
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3_2.llama3_2_1b
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct
+  checkpoint_files: [
+    model.safetensors
+  ]
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 32
+epochs: 1
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: True
+  lr: 2e-5
+loss:
+  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1  # Use to increase virtual batch size
+compile: False  # pytorch compile, set to true for better perf/memory
+optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True  # True reduces memory
+enable_activation_offloading: False  # True reduces memory
+
+# Reduced precision
+dtype: bf16
+
+# Logging
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}
+output_dir: /tmp/torchtune/llama3_2_1b_superblock/
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False
+
+  # Output directory of trace artifacts
+  output_dir: ${output_dir}/profiling_outputs
+
+  # `torch.profiler.ProfilerActivity` types to trace
+  cpu: True
+  cuda: True
+
+  # Trace options passed to `torch.profiler.profile`
+  profile_memory: False
+  with_stack: False
+  record_shapes: True
+  with_flops: False
+
+  # `torch.profiler.schedule` options:
+  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
diff --git a/examples/torchtune/full_finetune_distributed.py b/examples/torchtune/full_finetune_distributed.py
new file mode 100644
index 0000000000..5505116dc0
--- /dev/null
+++ b/examples/torchtune/full_finetune_distributed.py
@@ -0,0 +1,936 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import time
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from omegaconf import DictConfig, ListConfig
+
+from torch import nn
+from torch.distributed import destroy_process_group, init_process_group
+
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import config, modules, training, utils
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
+from torchtune.datasets import ConcatDataset
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.lr_schedulers import get_lr
+
+from tqdm import tqdm
+
+log = utils.get_logger("DEBUG")
+
+
+class FullFinetuneRecipeDistributed(FTRecipeInterface):
+    """
+    Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
+    distributed training and can be run on a single node (1 to 8 GPUs).
+
+    Features:
+        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+            DDP is currently not supported. Training on CPU is not supported.
+
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+            activations in memory and instead recompute them during the backward pass.
This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. 
+ RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + if ( + cfg.get("fsdp_cpu_offload", False) + and cfg.optimizer.get("fused", False) + and not utils.torch_version_ge("2.4.0") + ): + raise RuntimeError( + "Using fused optimizer on CPU is only supported in PyTorch nightly." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + _, rank = training.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + # Optimizer in backward is not compatible with gradient accumulation or gradient clipping + if self._optimizer_in_bwd: + if self._clip_grad_norm is not None: + raise RuntimeError( + "Gradient clipping is not supported with optimizer in bwd." + "Please set clip_grad_norm=None, or optimizer_in_bwd=False." + ) + if self._gradient_accumulation_steps > 1: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. 
" + "Enabling activation offloading should reduce memory further.", + ) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, sampler, and dataloader. 
+ """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._compile = cfg.get("compile", False) + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=checkpoint_dict[training.MODEL_KEY], + ac_mode=cfg.get("ac_mode", None), + ac_option=cfg.get("ac_option", None), + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + + if self._is_rank_zero: + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. 
+ + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + ac_mode: Optional[str] = None, + ac_option: Optional[int] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + """ + + if self._is_rank_zero: + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + # with training.set_default_dtype(self._dtype), torch.device("meta"): + with training.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + + # Apply Sparsity + if True: + from torchao.sparsity.prototype.superblock.utils import ( + accelerate_with_sparsity, + get_args_parser, + simulate_sparsity, + ) + + superblock_args = get_args_parser(benchmark=True).parse_args([]) + superblock_args.sparsity = "bsr" + superblock_args.sparsity_linear = 0.9 + superblock_args.bsr = 64 + superblock_args.skip_attention_proj = True + simulate_sparsity(model, superblock_args) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + # We currently have two versions of activation checkpointing in this recipe + # for testing and BC purposes. 
``enable_activation_checkpointing`` controls + # the older version of AC and this behavior is unchanged + # ac_mode and ac_option together control selective AC. This is only enabled + # when these are set AND ``enable_activation_checkpointing`` is set to False + # We'll clean this up as soon as testing of AC is complete + if (not enable_activation_checkpointing) and (ac_mode is not None): + apply_selective_activation_checkpointing( + model, + ac_mode, + ac_option, + ) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + self._is_rank_zero, + # strict=True, + strict=False, + cpu_offload=fsdp_cpu_offload, + ) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + if self._is_rank_zero: + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + param: config.instantiate(cfg_optimizer, [param]) + for param in self._model.parameters() + } + + # Register optimizer step hooks on the model to run optimizer in backward. + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states for each param. If optimizer states are being restored in an optimizer in + # backward run, these need to have been saved with the same setting. Cannot restore from runs that + # did not use optimizer in backward. + if opt_state_dict is not None: + for param in opt_state_dict.keys(): + try: + training.load_from_full_optimizer_state_dict( + self._optim_ckpt_wrapper.state_dict()[param], + opt_state_dict[param], + self._device, + ) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." 
+ ) from e + if self._is_rank_zero: + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + collate_fn: str, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = training.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + if self._is_rank_zero: + log.info( + "Saving checkpoint. This may take some time. Retrieving full model state dict..." 
+ ) + start = time.perf_counter() + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), + self._is_rank_zero, + device=self._device, + ) + + if self._is_rank_zero: + log.info( + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" + ) + + if intermediate_checkpoint: + start = time.perf_counter() + if self._is_rank_zero: + log.info("Getting optimizer state dict...") + if not self._optimizer_in_bwd: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = {} + for param, opt in self._optim_ckpt_wrapper.optim_map.items(): + opt_state_dict[param] = training.get_full_optimizer_state_dict( + opt, self._is_rank_zero, device=self._device + ) + if self._is_rank_zero: + log.info( + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs" + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + + if self._is_rank_zero: + start = time.perf_counter() + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + torch.save(checkpoint_dict, f"{self._output_dir}/checkpoint_{epoch}.pth") + # self._checkpointer.save_checkpoint( + # checkpoint_dict, + # epoch=epoch, + # intermediate_checkpoint=intermediate_checkpoint, + # ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() + + def train(self) -> None: + """ + The core training loop. 
+ """ + # clean up before training begins + training.cleanup_before_training() + + world_size, rank = training.get_world_size_and_rank() + + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + else: + for opt in self._optim_ckpt_wrapper.optim_map.values(): + opt.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + + with self.activations_handling_ctx: + logits = self._model(**batch) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. 
+ labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_fn(logits, labels) * current_num_tokens + + # free logits otherwise it peaks backward memory + del logits + + running_loss += current_loss + + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: + torch.distributed.all_reduce(num_tokens) + torch.distributed.all_reduce(running_loss) + current_loss = current_loss / num_tokens + + current_loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + # Manually scale the gradients from unnormalized loss by total # of tokens + training.scale_grads(self._model, 1 / num_tokens) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() / num_tokens + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": get_lr( + ( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + ), + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. 
+ + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + + config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg) + + recipe = FullFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/examples/torchtune/full_finetune_single_device.py b/examples/torchtune/full_finetune_single_device.py new file mode 100644 index 0000000000..3a8ed14d11 --- /dev/null +++ b/examples/torchtune/full_finetune_single_device.py @@ -0,0 +1,817 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler + +from torchtune import config, modules, training, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed +from torchtune.datasets import ConcatDataset +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training.lr_schedulers import get_lr + +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class FullFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + Features: + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. 
As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * gradient accumulation steps. + + For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Optimizer in Backward. Fusing the optimizer step into the backward pass helps reduce the memory + footprint associated with gradients. This can be especially helpful when you are memory + constrained. Note that users can only use ONE of gradient accumulation or optimizer in backward. + These features currently do not work together. For more details on optimizer in backward, please + see this tutorial: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + + - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes + library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with + 8-bit AdamW and Paged AdamW. These optimizers are especially helpful when you are memory constrained + since they help reduce the memory footprint associated with the optimizer states. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. 
+ + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor + # enabled necessary features such as gradient scaling. + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._optimizer_in_bwd = cfg.optimizer_in_bwd + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + # Optimizer in backward is not compatible with gradient accumulation or gradient clipping + if self._optimizer_in_bwd: + if self._clip_grad_norm is not None: + raise RuntimeError( + "Gradient clipping is not supported with optimizer in bwd." + "Please set clip_grad_norm=None, or optimizer_in_bwd=False." + ) + if self._gradient_accumulation_steps > 1: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. 
" + "Enabling activation offloading should reduce memory further.", + ) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. + """ + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._compile = cfg.compile + if cfg.device == "npu" and cfg.compile: + raise ValueError( + "NPU does not support model compilation. Please set `compile: False` in the config." + ) + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + compile_model=self._compile, + model_state_dict=ckpt_dict[training.MODEL_KEY], + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. 
Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=cfg.optimizer_in_bwd, + opt_state_dict=( + ckpt_dict[training.OPT_KEY] if self._resume_from_checkpoint else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Setup lr scheduler + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.get("lr_scheduler", None), + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. 
code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + compile_model: bool, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + """ + Set up the model including enabling activation checkpointing. + """ + with training.set_default_dtype(self._dtype), self._device: + model = config.instantiate(cfg_model) + + if compile_model: + training.compile_model(model) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + model.load_state_dict(model_state_dict) + + if True: + from torchao.sparsity.prototype.superblock.utils import ( + accelerate_with_sparsity, + get_args_parser, + simulate_sparsity, + ) + + superblock_args = get_args_parser(benchmark=True).parse_args([]) + superblock_args.sparsity = "bsr" + superblock_args.sparsity_linear = 0.9 + superblock_args.bsr = 64 + superblock_args.skip_attention_proj = True + simulate_sparsity(model, superblock_args) + + # Validate model was loaded in with the expected dtype. + training.validate_expected_param_dtype( + model.named_parameters(), dtype=self._dtype + ) + + # Enable activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + log.info(f"Model is initialized with precision {self._dtype}.") + + if self._device.type != "cpu": + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + return model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + """ + Set up the optimizer. This method also handles loading the optimizer state_dict, if specified. + """ + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. 
+ optim_dict = { + p: config.instantiate(cfg_optimizer, [p]) + for p in self._model.parameters() + } + # Register optimizer step hooks on the model to run optimizer in backward. + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states. If optimizer states are being restored in an optimizer in backward + # run, these need to have been saved with the same setting. Cannot restore from runs that did not + # use optimizer in backward. + if opt_state_dict is not None: + try: + self._optim_ckpt_wrapper.load_state_dict(opt_state_dict) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." + ) from e + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + log.info("Optimizer is initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: Optional[DictConfig], + num_training_steps: int, + last_epoch: int, + ) -> Optional[Optimizer]: + """ + Set up the learning rate scheduler based on the provided configuration. + It handles both standard optimization and optimizer-in-backward cases, and supports + schedulers from both torchtune.modules and torch.optim. + + Args: + cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration. + num_training_steps (int): The total number of training steps. + last_epoch (int): The index of the last epoch. + + Returns: + lr_scheduler (Optional[Optimizer]): The learning rate scheduler. + """ + if cfg_lr_scheduler is None: + log.info( + "No learning rate scheduler configured. Using constant learning rate." + ) + return None + + if self._optimizer_in_bwd: + # Use the first optimizer from the wrapper to represent the learning rate + optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values())) + else: + # Standard case: use the single optimizer + optimizer = self._optimizer + + # Instantiate the learning rate scheduler + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + if self._optimizer_in_bwd: + # Modify the scheduler for optimizer_in_bwd case + self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler) + + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + collate_fn: str, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. 
+ """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), + ) + + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + ckpt_dict = {training.MODEL_KEY: self._model.state_dict()} + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: + ckpt_dict.update( + { + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + if not self._optimizer_in_bwd: + ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict() + else: + ckpt_dict[training.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), + ) + + def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + + with self.activations_handling_ctx: + logits = self._model(**batch) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + return loss + + def train(self) -> None: + """ + The core training loop. Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. + """ + if self._compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." 
+ ) + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_step(batch) * current_num_tokens + running_loss += current_loss + current_loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + training.scale_grads(self._model, 1 / num_tokens) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning + if self._lr_scheduler is not None: + self._lr_scheduler.step() + self.global_step += 1 + + loss_to_log = running_loss.item() / num_tokens + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + # NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently + # true since we don't expose the ability to configure this yet. 
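+ # get_lr reads the current learning rate from either the plain optimizer or, when + # optimizer_in_bwd is enabled, from the in-backward optimizer wrapper.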
+ "lr": get_lr( + ( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), + ), + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._device.type != "cpu" and self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="FullFinetuneRecipeSingleDevice", cfg=cfg) + recipe = FullFinetuneRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 19e42e7cd5..cddaf8bfb9 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -167,6 +167,7 @@ def main( save: bool = False, compile: bool = True, compile_prefill: bool = False, + superblock: bool = False, profile: Optional[Path] = None, memory_profile: Optional[Path] = None, device=default_device, @@ -273,6 +274,24 @@ def main( filename = str(checkpoint_path.name).split(".")[0] torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt")) + if superblock: + from torchao.sparsity.prototype.superblock.utils import ( + accelerate_with_sparsity, + get_args_parser, + simulate_sparsity, + ) + + superblock_args = get_args_parser(benchmark=True).parse_args([]) + superblock_args.sparsity = "bsr" + superblock_args.sparsity_linear = 0.9 + superblock_args.bsr = 64 + + sparsifier_or_none = simulate_sparsity(model, superblock_args) + if sparsifier_or_none is not None: + sparsifier_or_none.squash_mask() + + accelerate_with_sparsity(model, superblock_args) + if compile: print("Compiling Model") global decode_one_token, prefill @@ -426,6 +445,7 @@ def callback(x): parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + parser.add_argument('--superblock', action='store_true', help='Apply Superblock BSR sparsity') parser.add_argument('--profile', type=Path, default=None, help='Profile path.') parser.add_argument('--memory_profile', type=Path, 
default=None, help='filename for memory profile.') parser.add_argument('--device', type=str, default=default_device, help='Device to use') @@ -435,5 +455,5 @@ def callback(x): args = parser.parse_args() main( args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.superblock, args.profile, args.memory_profile, args.device, args.precision, args.write_result ) diff --git a/torchao/sparsity/prototype/superblock/README.md b/torchao/sparsity/prototype/superblock/README.md index 6fea1a0e3a..d0c5e52413 100644 --- a/torchao/sparsity/prototype/superblock/README.md +++ b/torchao/sparsity/prototype/superblock/README.md @@ -15,34 +15,6 @@ The BSR format is efficient for sparse matrices with a block structure, where no Currently, the BSR format is optimized for Nvidia A100 GPU(s) only. -## Setup -To use SuperBlock, you will need -* [PyTorch](https://pytorch.org/get-started/locally/) - -To train the model or evaluate accuracy, you will need: -* ImageNet2012-blurred dataset - -At least one GPU: -* A100 or H100 - -## Installation -* Clone this repo - ``` - git clone https://github.com/pytorch-labs/superblock.git - cd superblock - ``` -* Create a new conda environment - ``` - conda create -n superblock - conda activate superblock - ``` -* Install PyTorch. For best performance, we recommend the pytorch nightlies - ``` - pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 - ``` - We ran our experiments with torch==2.6.0.dev20240924+cu121 - - # Results ### Benchmarking diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py index 69c98f6afc..a85ce448e2 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse.py +++ b/torchao/sparsity/prototype/superblock/blocksparse.py @@ -92,7 +92,10 @@ def blocksparse_linear( bias: torch.Tensor, ) -> torch.Tensor: weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) - return torch.nn.functional.linear(A, weight_bsr, bias) + if A.dim() == 2: + return torch.nn.functional.linear(A, weight_bsr, bias) + else: + return torch.nn.functional.linear(A.flatten(start_dim=0, end_dim=-2), weight_bsr, bias).view(A.shape[0], A.shape[1], -1) @torch.library.register_fake("blocksparse::linear") diff --git a/torchao/sparsity/prototype/superblock/supermask.py b/torchao/sparsity/prototype/superblock/supermask.py index 0b28763445..44802348ce 100644 --- a/torchao/sparsity/prototype/superblock/supermask.py +++ b/torchao/sparsity/prototype/superblock/supermask.py @@ -7,6 +7,8 @@ import torch.nn.functional as F import numpy as np +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + # original supermask scores_min=None scores_max=9e9 @@ -35,6 +37,21 @@ def backward(ctx, g): return g, None, None, None +class ApplyMask(torch.autograd.Function): + """Supermask STE function""" + @staticmethod + def forward(ctx, weight, scores): + return weight * scores + @staticmethod + def backward(ctx, grad_output): + grad_weight = grad_scores = None + if 
ctx.needs_input_grad[0]: + grad_weight = grad_output + if ctx.needs_input_grad[1]: + grad_scores = grad_output + return grad_weight, grad_scores + + class SupermaskLinear(nn.Linear): """Supermask class for Linear layer""" def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs): @@ -109,7 +126,8 @@ def sparsify_offline(self): def forward(self, x): if not self.sparsify_weights: subnet = self.get_mask() - w = (self.weight*self.scale+self.shift) * subnet + w = (self.weight*self.scale+self.shift) + w = ApplyMask.apply(w, subnet) else: w = self.weight return F.linear(x, w, self.bias) @@ -179,7 +197,8 @@ def forward(self, x): subnet = subnet.repeat_interleave(self.tile_size, dim=i) subnet = torch.narrow(subnet, i, 0, k) - w = (self.weight*self.scale+self.shift) * subnet + w = (self.weight*self.scale+self.shift) + w = ApplyMask.apply(w, subnet) return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) def apply_supermask( @@ -190,86 +209,111 @@ def apply_supermask( conv1x1_sp_tilesize=1, conv_sparsity=0.0, conv_sp_tilesize=1, + skip_attention_proj=False, skip_last_layer_sparsity=False, skip_first_transformer_sparsity=False, device="cuda", verbose=False, ): - sparsified_modules = {} + # create filter function + # TODO: move the filtering function to the script calling this function and instead pass to apply_supermask the filter_fn + is_last_layer = lambda module, name: name == "heads.head" + is_first_transformer_layer = lambda module, name: name == "encoder.layers.encoder_layer_0" + is_attn_layer = lambda module, name: "attn" in name or "attention" in name + # TODO: create condition for ffn, k,v,q,o projections + reject_fn = lambda module, name : (skip_last_layer_sparsity and is_last_layer(module, name)) or (skip_first_transformer_sparsity and is_first_transformer_layer(module, name)) or (skip_attention_proj and is_attn_layer(module, name)) + filter_fn = lambda module, name : not reject_fn(module, name) and isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)) - for n, m in model.named_modules(): - # check conditions for skipping sparsity - if skip_last_layer_sparsity and n == "heads.head": - continue - if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: - continue - - # convert 1x1 convolutions - if conv1x1_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d) and m.kernel_size == (1, 1): - new_m = SupermaskConv2d( - conv1x1_sparsity, False, False, None, None, None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv1x1_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue + _replace_with_custom_fn_if_matches_filter( + model, + SuperMaskReplacementClass( + linear_sparsity=linear_sparsity, + linear_sp_tilesize=linear_sp_tilesize, + conv1x1_sparsity=conv1x1_sparsity, + conv1x1_sp_tilesize=conv1x1_sp_tilesize, + conv_sparsity=conv_sparsity, + conv_sp_tilesize=conv_sp_tilesize, + device=device, + verbose=verbose, + ), + filter_fn, + ) - # convert all other convolutions (not tested!) 
- if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d): - new_m = SupermaskConv2d( - conv_sparsity, False, False, None, None, None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue +class SuperMaskReplacementClass: + def __init__( + self, + linear_sparsity=0.0, + linear_sp_tilesize=1, + conv1x1_sparsity=0.0, + conv1x1_sp_tilesize=1, + conv_sparsity=0.0, + conv_sp_tilesize=1, + device="cuda", + verbose=False, + ): + self.linear_sparsity = linear_sparsity + self.linear_sp_tilesize = linear_sp_tilesize + self.conv1x1_sparsity = conv1x1_sparsity + self.conv1x1_sp_tilesize = conv1x1_sp_tilesize + self.conv_sparsity = conv_sparsity + self.conv_sp_tilesize = conv_sp_tilesize + self.device = device + self.verbose = verbose - if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): - new_m = SupermaskLinear( - linear_sparsity, False, False, None, None, None, - m.in_features, - m.out_features, - bias=m.bias is not None, - device=device, - tile_size=linear_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue + def __call__(self, module): + module_new = None - # add modules to model - for k, v in sparsified_modules.items(): - sm_name, ch_name = k.rsplit(".", 1) - sm = model.get_submodule(sm_name) - sm.add_module(ch_name, v) + if self.conv1x1_sparsity != 0.0 and isinstance(module, torch.nn.Conv2d) and module.kernel_size == (1, 1): + # convert 1x1 convolutions + module_new = SupermaskConv2d( + self.conv1x1_sparsity, False, False, None, None, None, + module.in_channels, + module.out_channels, + module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=module.bias is not None, + padding_mode=module.padding_mode, + tile_size=self.conv1x1_sp_tilesize, + ).to(device=self.device, dtype=module.weight.dtype) + module_new.weight.data.copy_(module.weight.data) + if module.bias is not None: + module_new.bias.data.copy_(module.bias.data) + elif self.conv_sparsity != 0.0 and isinstance(module, torch.nn.Conv2d): + # convert all other convolutions (not tested!) 
+ module_new = SupermaskConv2d( + self.conv_sparsity, False, False, None, None, None, + module.in_channels, + module.out_channels, + module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + bias=module.bias is not None, + padding_mode=module.padding_mode, + tile_size=self.conv_sp_tilesize, + ).to(device=self.device, dtype=module.weight.dtype) + module_new.weight.data.copy_(module.weight.data) + if module.bias is not None: + module_new.bias.data.copy_(module.bias.data) + elif self.linear_sparsity != 0.0 and isinstance(module, torch.nn.Linear): + module_new = SupermaskLinear( + self.linear_sparsity, False, False, None, None, None, + module.in_features, + module.out_features, + bias=module.bias is not None, + tile_size=self.linear_sp_tilesize, + ).to(device=self.device, dtype=module.weight.dtype) + module_new.weight.data.copy_(module.weight.data) + if module.bias is not None: + module_new.bias.data.copy_(module.bias.data) + else: + return module - if verbose: - print(f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}') + if self.verbose: + print(f'sparsified module "{module}" with sparsity={module_new.sparsity}, tile size={module_new.tile_size}') - return model + return module_new diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py index cf865fd369..c77e900670 100644 --- a/torchao/sparsity/prototype/superblock/utils.py +++ b/torchao/sparsity/prototype/superblock/utils.py @@ -120,7 +120,7 @@ def mlp_only(mod, name): def superblock_only(mod, name): - return isinstance(mod, SupermaskLinear) and "mlp" in name + return isinstance(mod, SupermaskLinear)# and "mlp" in name def mlp_only_with_args( @@ -138,7 +138,7 @@ def mlp_only_with_args( ### Custom sparsification utils def apply_sparsity(model): for name, module in model.named_modules(): - if isinstance(module, SupermaskLinear) and "mlp" in name: + if isinstance(module, SupermaskLinear):# and "mlp" in name: # TODO: add option in another function for "mlp" in name module.sparsify_offline() @@ -185,6 +185,7 @@ def simulate_sparsity(model, args): conv1x1_sp_tilesize=args.bsr, conv_sparsity=args.sparsity_conv, conv_sp_tilesize=args.bsr, + skip_attention_proj=args.skip_attention_proj, skip_last_layer_sparsity=args.skip_last_layer_sparsity, skip_first_transformer_sparsity=args.skip_first_transformer_sparsity, device=args.device,