From e929c1936beafe836c75cf09da8b215ed542c478 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 18 Jun 2025 14:52:02 -0700 Subject: [PATCH 1/3] add float auto_filter_for_recipe --- docs/float8.md | 2 + torchtitan/components/quantization/float8.py | 112 +++++++++++++------ 2 files changed, 81 insertions(+), 33 deletions(-) diff --git a/docs/float8.md b/docs/float8.md index 63a029e60..a3d806c92 100644 --- a/docs/float8.md +++ b/docs/float8.md @@ -17,6 +17,8 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth. * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward. +* `--float8.filter_fqns="..."` (optional): a comma separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using. + * **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to to enable automatic module filtering, which will automatically not convert linear layers are not large enough to benefit from float8 training, since the GEMM has to be big enough that the speedup from using FP8 tensorcores is greater than the overhead of creating dynamically quantized inputs. The thresholds for conversion are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K,N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training. * `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 782889716..29f46084d 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - from functools import partial import torch @@ -20,6 +19,8 @@ from .utils import module_filter_fn +AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn" + class Float8Converter(ModelConverter): def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): @@ -54,14 +55,19 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.enabled = True self.filter_fqns = float8_config.filter_fqns self.moe_fqns = float8_config.moe_fqns_prototype + self.filter_fn = self._init_filter_fn(float8_config) if float8_config.recipe_name is not None: - assert ( - not float8_config.enable_fsdp_float8_all_gather - ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported" - assert ( - not float8_config.force_recompute_fp8_weight_in_bwd - ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported" + assert not float8_config.enable_fsdp_float8_all_gather, ( + "using `float8_config.enable_fsdp_float8_all_gather` together " + "with `float8_config.recipe_name` is not supported" + ) + + assert not float8_config.force_recompute_fp8_weight_in_bwd, ( + "using `float8_config.force_recompute_fp8_weight_in_bwd` together " + "with `float8_config.recipe_name` is not supported" + ) + self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name) self.precompute_scale = False logger.info( @@ -74,7 +80,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): logger.debug( "Set torch._inductor.config.emulate_precision_casts to True" ) - else: # Mutates the model inplace replacing instances of nn.Linear with Float8Linear enable_fsdp_float8_all_gather = ( @@ -93,6 +98,41 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ) logger.info("Float8 tensorwise scaled training active") + def _init_filter_fn(self, float8_config: Float8): + # use auto_filter if filter_fqns "auto_filter_small_kn" is one of the given fqns. + use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns + if use_auto_filter: + try: + from torchao.float8 import _auto_filter_for_recipe + + logger.info( + "Using automatic module filter for float8 model conversion." + ) + + recipe_name = ( + float8_config.recipe_name + if float8_config.recipe_name + else "tensorwise" + ) + + # remove auto filter flag from filter_fqns before passing to _auto_filter_for_recipe + float8_config.filter_fqns.remove(AUTO_FILTER_SMALL_KN_FLAG) + + return _auto_filter_for_recipe( + recipe_name, + filter_fqns=float8_config.filter_fqns, + ) + except ImportError: + logger.warning( + ( + "Using default module_filter_fn for float8 model conversion. " + "To use _auto_filter_for_recipe, please install torchao nightly build." + ) + ) + + # use default filter func + return partial(module_filter_fn, filter_fqns=float8_config.filter_fqns) + def convert(self, model: nn.Module): """ This function converts the linear layers of `model` to `Float8Linear`. @@ -102,36 +142,12 @@ def convert(self, model: nn.Module): if not self.enabled: return - # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor, - # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. # MoE conversion must take place before Float8Linear conversion, otherwise the Float8Linears will # be converted back to nn.Linear: # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299 # TODO: add warning in torchao when this happens, or find a better way to avoid this. if self.moe_fqns: - from torchao.quantization.quant_api import quantize_ - - try: - from torchao.prototype.moe_training.conversion_utils import ( - MoETrainingConfig, - ) - except ImportError as e: - raise ImportError( - "torchao installation does not have MoE training support. Please install torchao nightly build." - ) from e - - def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: - for target_fqn in self.moe_fqns: - if target_fqn in cur_fqn: - return True - return False - - config = MoETrainingConfig() - quantize_(model, config=config, filter_fn=moe_module_filter_fn) - logger.info( - f"Converted MoE layers matching FQNS {self.moe_fqns} " - "to use dynamic float8 rowwise quantization with scaled grouped GEMMs" - ) + self._convert_moe_layers(model) from torchao.float8 import convert_to_float8_training @@ -146,6 +162,36 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: f"{self.config.enable_fsdp_float8_all_gather}" ) + def _convert_moe_layers(self, model: nn.Module): + """ + Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor, + to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. + """ + from torchao.quantization.quant_api import quantize_ + + try: + from torchao.prototype.moe_training.conversion_utils import ( + MoETrainingConfig, + ) + except ImportError as e: + raise ImportError( + "torchao installation does not have MoE training support. Please install torchao nightly build." + ) from e + + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in self.moe_fqns: + if target_fqn in cur_fqn: + return True + return False + + config = MoETrainingConfig() + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + logger.info( + f"Converted MoE layers matching FQNS {self.moe_fqns} " + "to use dynamic float8 rowwise quantization with scaled grouped GEMMs" + ) + + def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): if not self.enabled: return From c5cb3c55717d7d90779115cd3653fe927a9d9cd4 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 30 Jun 2025 09:05:44 -0700 Subject: [PATCH 2/3] add better logging --- torchtitan/components/quantization/float8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 29f46084d..39dd03765 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -106,7 +106,7 @@ def _init_filter_fn(self, float8_config: Float8): from torchao.float8 import _auto_filter_for_recipe logger.info( - "Using automatic module filter for float8 model conversion." + "Using _auto_filter_for_recipe to avoid converting linear layers with dims too small to benefit from float8 training. See docs/float8.md for more info." ) recipe_name = ( From 9c79cfd6dfc89311a372768dfc0e2c2ebc7c3e1b Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 30 Jun 2025 09:29:05 -0700 Subject: [PATCH 3/3] lint --- torchtitan/components/quantization/float8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 39dd03765..4d31a43ff 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -106,7 +106,8 @@ def _init_filter_fn(self, float8_config: Float8): from torchao.float8 import _auto_filter_for_recipe logger.info( - "Using _auto_filter_for_recipe to avoid converting linear layers with dims too small to benefit from float8 training. See docs/float8.md for more info." + "Using _auto_filter_for_recipe to avoid converting linear layers with dims too small " + "to benefit from float8 training. See docs/float8.md for more info." ) recipe_name = ( @@ -191,7 +192,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: "to use dynamic float8 rowwise quantization with scaled grouped GEMMs" ) - def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): if not self.enabled: return