
TorchAO

PyTorch-Native Training-to-Serving Model Optimization

  • Pre-train Llama-3.1-70B 1.5x faster with float8 training
  • Recover 67% of quantized accuracy degradation on Gemma3-4B with QAT
  • Quantize Llama-3-8B to int4 for 1.89x faster inference with 58% less memory

📣 Latest News


🌅 Overview

TorchAO is an easy-to-use quantization library for native PyTorch. TorchAO works out of the box with torch.compile() and FSDP2 across most Hugging Face PyTorch models. Key features include:

Check out our docs for more details!

🚀 Quick Start

First, install TorchAO. We recommend installing the latest stable version:

pip install torchao

Quantize your model weights to int4!

from torchao.quantization import Int4WeightOnlyConfig, quantize_
quantize_(model, Int4WeightOnlyConfig(group_size=32, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))

See our quick start guide for more details.
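To make `group_size` concrete, here is a minimal pure-Python sketch of group-wise asymmetric int4 quantization: every run of `group_size` consecutive weights gets its own scale and zero point, so an outlier only degrades its own group. The helper names are hypothetical, the quantization parameters come from a plain min/max rule rather than HQQ, and no `tile_packed_to_4d` packing is modeled.

```python
# Hypothetical helpers: min/max qparams (not HQQ), no tile_packed_to_4d packing.

def quantize_int4_groupwise(row, group_size=32):
    """Quantize a list of floats to unsigned int4 codes (0..15), one scale per group."""
    codes, scales, zeros = [], [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        lo, hi = min(group), max(group)
        # 16 int4 levels span 15 steps; a constant group gets a dummy scale of 1.0
        scale = (hi - lo) / 15 if hi > lo else 1.0
        for x in group:
            code = round((x - lo) / scale)
            codes.append(min(max(code, 0), 15))  # clamp to the int4 range
        scales.append(scale)
        zeros.append(lo)  # float zero point: the value that code 0 represents
    return codes, scales, zeros


def dequantize_int4_groupwise(codes, scales, zeros, group_size=32):
    """Map int4 codes back to floats using each group's scale and zero point."""
    return [
        code * scales[i // group_size] + zeros[i // group_size]
        for i, code in enumerate(codes)
    ]


# Usage: 128 weights -> 4 groups of 32, each with its own (scale, zero point)
row = [((i * 37) % 101) / 50.0 - 1.0 for i in range(128)]
codes, scales, zeros = quantize_int4_groupwise(row, group_size=32)
restored = dequantize_int4_groupwise(codes, scales, zeros, group_size=32)
print(len(scales), max(abs(a - b) for a, b in zip(row, restored)))
```

Smaller groups fit each group's range more tightly, at the cost of storing more scales and zero points.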

🛠 Installation

To install the latest stable version:

pip install torchao
Other installation options
# Nightly
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu128

# Different CUDA versions
pip install torchao --index-url https://download.pytorch.org/whl/cu126  # CUDA 12.6
pip install torchao --index-url https://download.pytorch.org/whl/cu129  # CUDA 12.9
pip install torchao --index-url https://download.pytorch.org/whl/cpu    # CPU only

# For developers (from a local clone of the repository)
# Note: the `--no-build-isolation` flag is required.
USE_CUDA=1 pip install -e . --no-build-isolation
# or, to skip building the C++/CUDA extensions:
USE_CPP=0 pip install -e . --no-build-isolation

Please see the torchao compatibility table for the version requirements of dependencies.

🔗 Integrations

TorchAO is integrated into some of the leading open-source libraries including:

🔎 Inference

TorchAO delivers substantial performance gains with minimal code changes:

The following is our recommended flow for quantization and deployment:

from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

# Create quantization configuration
quantization_config = TorchAoConfig(quant_type=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# Load and automatically quantize
quantized_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-32B",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

If the above doesn't work, use the alternative quantize_ API described in the quick start guide.
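Per-row granularity (`PerRow()`) means each row of a weight matrix gets its own float8 scale. The sketch below simulates this in plain Python under stated assumptions: float8 is emulated by rounding float64 values to the e4m3fn grid (3 mantissa bits, maximum 448), each row's scale is `amax / 448`, and out-of-range values saturate. The helper names are hypothetical, and this is not torchao's kernel path.

```python
import math

E4M3_MAX = 448.0  # largest finite float8 e4m3fn value


def round_to_e4m3(x):
    """Round a float to the nearest e4m3fn-representable value (simulated).

    Assumes 3 mantissa bits, min normal 2**-6, subnormal spacing 2**-9,
    and saturation at +/-448 instead of overflow.
    """
    if x == 0.0:
        return 0.0
    _, e = math.frexp(x)             # x = m * 2**e with 0.5 <= |m| < 1
    step = 2.0 ** (max(e, -5) - 4)   # spacing of representable values near x
    y = round(x / step) * step       # round half to even
    return max(-E4M3_MAX, min(E4M3_MAX, y))


def quantize_rows_fp8(matrix):
    """Hypothetical per-row scaling: scale_r = amax(row) / 448."""
    q_rows, scales = [], []
    for row in matrix:
        amax = max(abs(v) for v in row)
        scale = amax / E4M3_MAX if amax > 0 else 1.0
        q_rows.append([round_to_e4m3(v / scale) for v in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_rows_fp8(q_rows, scales):
    """Undo the per-row scaling."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


# Usage: two rows with very different magnitudes
weights = [[0.001, -0.002, 0.0005], [100.0, -250.0, 30.0]]
q_rows, scales = quantize_rows_fp8(weights)
restored = dequantize_rows_fp8(q_rows, scales)
print(restored)
```

Because each row is scaled independently, the tiny first row keeps about the same relative precision as the large second row; with a single scale shared across the whole matrix, its values would instead land in e4m3's subnormal range or round to zero.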

Serving with vLLM on a 1xH100 machine:

# Server
VLLM_DISABLE_COMPILE_CACHE=1 vllm serve pytorch/Qwen3-32B-FP8 --tokenizer Qwen/Qwen3-32B -O3
# Client
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "pytorch/Qwen3-32B-FP8",
  "messages": [
    {"role": "user", "content": "Give me a short introduction to large language models."}
  ],
  "temperature": 0.6,
  "top_p": 0.95,
  "top_k": 20,
  "max_tokens": 32768
}'
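The same request can be sent from Python. This sketch uses only the standard library and assumes the vLLM server started above is listening on `localhost:8000`; since vLLM exposes an OpenAI-compatible `/v1/chat/completions` endpoint, the reply text is read from `choices[0].message.content`. The helper names are hypothetical.

```python
import json
import urllib.request


def build_chat_request(prompt, model="pytorch/Qwen3-32B-FP8",
                       url="http://localhost:8000/v1/chat/completions"):
    """Build the same OpenAI-compatible chat request as the curl command."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.6,
        "top_p": 0.95,
        "top_k": 20,
        "max_tokens": 32768,
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_chat_request(request, timeout=600):
    """Send the request and return the assistant's reply text."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]


# Usage (requires the vLLM server above to be running):
# req = build_chat_request("Give me a short introduction to large language models.")
# print(send_chat_request(req))
```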

We also support deployment to edge devices through ExecuTorch; for more details, see the quantization and serving guide. We also release pre-quantized models here.

🚅 Training

Quantization-Aware Training

Post-training quantization (PTQ) can produce a fast and compact model, but may also degrade accuracy. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with TorchTune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ: for Llama3, it recovers 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext caused by PTQ. For more details, please refer to the QAT README and the original blog:

from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# train model (not shown)

# convert
quantize_(my_model, QATConfig(base_config, step="convert"))

Using this fine-tuning recipe, users can also combine LoRA with QAT to speed up training by 1.89x compared to vanilla QAT.
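QAT is commonly built on fake quantization: during training the forward pass sees weights rounded to the low-bit grid and immediately dequantized back to float, and because rounding has no useful gradient, a straight-through estimator (STE) passes gradients to the full-precision weights unchanged. The toy sketch below shows that mechanism for a single weight with a fixed int4 scale of 0.1; it is a generic illustration with hypothetical helper names, not the implementation behind `QATConfig`.

```python
def fake_quantize(w, scale, qmin=-8, qmax=7):
    """Quantize to the signed int4 grid, then dequantize (result stays a float)."""
    q = min(max(round(w / scale), qmin), qmax)
    return q * scale


def train_qat_scalar(xs, ys, w=0.0, scale=0.1, lr=0.01, steps=200):
    """Fit y ~ w * x while the forward pass always uses fake_quantize(w).

    STE: treat d fake_quantize(w) / dw as 1, so dL/dwq updates w directly.
    """
    for _ in range(steps):
        wq = fake_quantize(w, scale)
        # dL/dwq for L = mean((wq * x - y) ** 2)
        grad = sum(2 * (wq * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad  # straight-through: apply the gradient to the float weight
    return w, fake_quantize(w, scale)


# Usage: the target weight 0.37 lies between the grid points 0.3 and 0.4
xs = [1.0, 2.0, 3.0]
ys = [0.37 * x for x in xs]
w, wq = train_qat_scalar(xs, ys)
print(w, wq)
```

Since 0.37 is not on the grid, the latent weight settles near the 0.35 rounding boundary and its quantized value ends at one of the two neighbouring grid points; this oscillation is a known behaviour of the straight-through estimator.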

Quantized training

torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. With torch.compile on, current results show throughput speedups of up to 1.5x at scales of up to 512 GPUs and 405B parameters (details):

from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m)
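The idea behind a scaled float8 matmul: scale each operand so its largest magnitude maps to the float8 maximum, cast, multiply, and divide the result by the product of the scales. The pure-Python sketch below simulates one such matmul with a single scale per tensor; float8 is emulated by rounding float64 values to the e4m3fn grid, the helper names are hypothetical, and it makes no claim about which scaling recipe or kernels `convert_to_float8_training` actually uses.

```python
import math

E4M3_MAX = 448.0  # largest finite float8 e4m3fn value


def round_to_e4m3(x):
    """Simulated e4m3fn rounding: 3 mantissa bits, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(x)
    step = 2.0 ** (max(e, -5) - 4)  # spacing of representable values near x
    return max(-E4M3_MAX, min(E4M3_MAX, round(x / step) * step))


def to_fp8_tensorwise(matrix):
    """One scale per tensor, chosen so that amax * scale == 448."""
    amax = max(abs(v) for row in matrix for v in row)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [[round_to_e4m3(v * scale) for v in row] for row in matrix], scale


def scaled_matmul(a, b):
    """Approximate A @ B from float8-rounded operands, then undo both scales."""
    a8, sa = to_fp8_tensorwise(a)
    b8, sb = to_fp8_tensorwise(b)
    n, k, m = len(a8), len(b8), len(b8[0])
    # float64 accumulation stands in for the higher-precision accumulator
    return [
        [sum(a8[i][p] * b8[p][j] for p in range(k)) / (sa * sb) for j in range(m)]
        for i in range(n)
    ]


# Usage: compare against the exact product [[1, 3], [2.5, 5]]
print(scaled_matmul([[1.0, 2.0], [3.0, 4.0]], [[0.5, -1.0], [0.25, 2.0]]))
```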

Our float8 training is integrated into TorchTitan's pre-training flows so users can easily try it out. For more details, check out these blog posts about our float8 training support:

Other features (sparse training, memory-efficient optimizers)

Sparse Training

We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L. Full blog here. The code change is a one-liner, and the full example is available here:

from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
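A 2:4 pattern means that in every contiguous group of four weights, at most two are nonzero; this is the layout that sparse tensor cores on NVIDIA Ampere and newer GPUs can accelerate. The sketch below enforces the pattern on a flat list by keeping the two largest-magnitude values per group. It only illustrates the sparsity pattern: the helper names are hypothetical, and it does not reproduce how `SemiSparseLinear` chooses values or runs its kernels.

```python
def prune_2_4(weights):
    """Zero the two smallest-magnitude values in every group of four."""
    if len(weights) % 4:
        raise ValueError("length must be a multiple of 4")
    out = list(weights)
    for start in range(0, len(out), 4):
        group = range(start, start + 4)
        # indices of the two smallest |w| in this group (stable on ties)
        for i in sorted(group, key=lambda idx: abs(out[idx]))[:2]:
            out[i] = 0.0
    return out


def is_2_4_sparse(weights):
    """True if every group of four contains at least two zeros."""
    return all(
        sum(1 for w in weights[s:s + 4] if w == 0) >= 2
        for s in range(0, len(weights), 4)
    )


# Usage
pruned = prune_2_4([1.0, -3.0, 2.0, 0.5, 4.0, 0.0, -1.0, 2.0])
print(pruned)  # [0.0, -3.0, 2.0, 0.0, 4.0, 0.0, 0.0, 2.0]
print(is_2_4_sparse(pruned))  # True
```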

Memory-efficient optimizers

Optimizers like Adam can consume substantial GPU memory: their state takes 2x as much memory as the model parameters themselves. TorchAO provides two approaches to reduce this overhead:

1. Quantized optimizers: Reduce optimizer state memory by 2-4x by quantizing to lower precision

from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8
optim = AdamW8bit(model.parameters())  # replace with AdamW4bit or AdamWFp8 for the 4-bit / FP8 versions

Our quantized optimizers are implemented in just a few hundred lines of PyTorch code and compiled for efficiency. While slightly slower than specialized kernels, they offer an excellent balance of memory savings and performance. See detailed benchmarks here.

2. CPU offloading: Move optimizer state and gradients to CPU memory

For maximum memory savings, we support single-GPU CPU offloading that efficiently moves both gradients and optimizer state to CPU memory. This approach can reduce your VRAM requirements by 60% with minimal impact on training speed:

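import torch
from torchao.optim import CPUOffloadOptimizer
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])  # `ckpt`: a previously saved checkpoint dict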
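The 2x figure follows from simple arithmetic: Adam keeps two states per parameter (`exp_avg` and `exp_avg_sq`), so with fp32 parameters and fp32 states the state alone is twice the size of the weights. The sketch below estimates state size at different precisions for a hypothetical 7B-parameter model, ignoring the per-block scales that quantized states also store; CPU offloading leaves these sizes unchanged but moves the state (and gradients) to host RAM.

```python
def adam_state_gib(num_params, state_bits):
    """GiB for Adam's two per-parameter states (exp_avg, exp_avg_sq).

    Ignores the small overhead of per-block scales in quantized states.
    """
    return 2 * num_params * state_bits / 8 / 2**30


num_params = 7_000_000_000  # hypothetical 7B-parameter model
param_gib = num_params * 4 / 2**30  # fp32 parameters, 4 bytes each
for name, bits in [("fp32", 32), ("bf16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{name:>5} Adam state: {adam_state_gib(num_params, bits):6.1f} GiB "
          f"(fp32 params: {param_gib:.1f} GiB)")
```

Relative to fp32 state, 8-bit and 4-bit states are 4x and 8x smaller; relative to bf16 state, 2x and 4x.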

🎥 Videos

💬 Citation

If you find the torchao library useful, please cite it in your work as below.

@software{torchao,
  title={TorchAO: PyTorch-Native Training-to-Serving Model Optimization},
  author={torchao},
  url={https://github.com/pytorch/ao},
  license={BSD-3-Clause},
  month={oct},
  year={2024}
}