Bump version for float8 dynamic quant and weight only quant configs #2650

Merged · 1 commit merged on Aug 7, 2025
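As the updated tests below show, `Float8DynamicActivationFloat8WeightConfig` and `Float8WeightOnlyConfig` now default to a bumped `version`, and tests that still cover the legacy flow pin `version=1` explicitly. A minimal sketch of the two call patterns (the toy modules and shapes are illustrative; like the tests, this assumes CUDA with compute capability >= 8.9):

```python
import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

# Illustrative toy modules; any nn.Linear-based model works the same way.
model_new = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False)).cuda().to(torch.bfloat16)
model_old = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False)).cuda().to(torch.bfloat16)

# Default flow after this PR: the config carries the bumped version.
quantize_(model_new, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# Legacy flow, pinned explicitly the same way the updated tests do.
quantize_(
    model_old,
    Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1),
)
```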
26 changes: 15 additions & 11 deletions test/core/test_config.py
@@ -7,6 +7,7 @@
import json
import os
import tempfile
import warnings
from dataclasses import dataclass
from unittest import mock

@@ -15,7 +16,6 @@

from torchao.core.config import (
AOBaseConfig,
VersionMismatchError,
config_from_dict,
config_to_dict,
)
@@ -151,7 +151,9 @@ def test_reconstructable_dict_file_round_trip(config):
# Define a dummy config in a non-allowed module
@dataclass
class DummyNonAllowedConfig(AOBaseConfig):
VERSION = 2
# NOTE: must be `version: int` (with a type annotation) to
# override the version variable from AOBaseConfig
version: int = 2
value: int = 42


@@ -172,11 +174,11 @@ def test_disallowed_modules():
reconstructed = config_from_dict(reconstructable)
assert isinstance(reconstructed, DummyNonAllowedConfig)
assert reconstructed.value == 42
assert reconstructed.VERSION == 2
assert reconstructed.version == 2


def test_version_mismatch():
"""Test that version mismatch raises an error during reconstruction."""
"""Test that version mismatch prints a warning during reconstruction."""
# Create a config
dummy_config = DummyNonAllowedConfig()
reconstructable = config_to_dict(dummy_config)
@@ -186,25 +188,27 @@ def test_version_mismatch():

# Patch to allow the module but should still fail due to version mismatch
with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
with pytest.raises(
VersionMismatchError,
match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
):
with warnings.catch_warnings(record=True) as caught_warnings:
config_from_dict(reconstructable)
assert any(
"Stored version is not the same as current default version of the config"
in str(w.message)
for w in caught_warnings
), "Didn't get expected warning message for version mismatch"


def test_default_version():
"""Making sure the default version for a new config inheriting from AOBaseConfig is always 1
because it's the default VERSION that all children has when they haven't explicitly
defined a VERSION class variable
because it's the default version that all children have when they haven't explicitly
defined a version class variable
"""

@dataclass
class DummyConfig(AOBaseConfig):
pass

config = DummyConfig()
assert config.VERSION == 1, "Default version must be 1"
assert config.version == 1, "Default version must be 1"


if __name__ == "__main__":
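The NOTE added to `DummyNonAllowedConfig` above is the key detail for downstream configs: the override must be a typed `version: int` dataclass field, not a bare `VERSION` class attribute. A minimal sketch (the config names are hypothetical; reconstructing a config defined outside torchao additionally requires its module to be allowlisted, which the test handles with `mock.patch` of `ALLOWED_AO_MODULES`):

```python
from dataclasses import dataclass

from torchao.core.config import AOBaseConfig


@dataclass
class MyCustomConfig(AOBaseConfig):
    # Typed field: this actually overrides the `version` attribute inherited
    # from AOBaseConfig; a bare `VERSION = 2` class variable would not.
    version: int = 2
    value: int = 42


@dataclass
class MyDefaultConfig(AOBaseConfig):
    pass


assert MyCustomConfig().version == 2
assert MyDefaultConfig().version == 1  # library default when not overridden
```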
76 changes: 52 additions & 24 deletions test/dtypes/test_affine_quantized_float.py
@@ -30,17 +30,14 @@
from torchao.float8.float8_utils import compute_error
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
float8_dynamic_activation_float8_weight,
float8_weight_only,
Float8StaticActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
quantize_,
)
from torchao.quantization.granularity import (
PerRow,
PerTensor,
)
from torchao.quantization.quant_api import (
float8_static_activation_float8_weight,
)
from torchao.quantization.quant_primitives import (
MappingType,
_choose_scale_float8,
@@ -119,11 +116,13 @@ def test_fp8_linear_variants(
)
mode_map = {
"dynamic": partial(
float8_dynamic_activation_float8_weight, granularity=granularity
Float8DynamicActivationFloat8WeightConfig,
granularity=granularity,
version=1,
),
"weight-only": float8_weight_only,
"weight-only": partial(Float8WeightOnlyConfig, version=1),
"static": partial(
float8_static_activation_float8_weight,
Float8StaticActivationFloat8WeightConfig,
scale=scale,
granularity=granularity,
),
@@ -152,7 +151,7 @@
)
def test_invalid_granularity(self):
with pytest.raises(ValueError, match="Invalid granularity specification"):
float8_dynamic_activation_float8_weight(granularity="invalid")
Float8DynamicActivationFloat8WeightConfig(granularity="invalid")

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -162,7 +161,9 @@ def test_mismatched_granularity(self):
ValueError,
match="Different granularities for activation and weight are not supported",
):
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
Float8DynamicActivationFloat8WeightConfig(
granularity=(PerTensor(), PerRow())
)

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -172,8 +173,8 @@ class UnsupportedGranularity:
pass

with pytest.raises(ValueError, match="Invalid granularity types"):
float8_dynamic_activation_float8_weight(
granularity=(UnsupportedGranularity(), UnsupportedGranularity())
Float8DynamicActivationFloat8WeightConfig(
granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -187,7 +188,8 @@ def test_per_row_with_float32(self):
):
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
quantize_(
model, float8_dynamic_activation_float8_weight(granularity=PerRow())
model,
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -201,15 +203,18 @@ def test_serialization(self, mode: str):

mode_map = {
"dynamic": partial(
float8_dynamic_activation_float8_weight, granularity=PerTensor()
Float8DynamicActivationFloat8WeightConfig,
granularity=PerTensor(),
version=1,
),
"weight-only": float8_weight_only,
"weight-only": partial(Float8WeightOnlyConfig, version=1),
"static": partial(
float8_static_activation_float8_weight,
Float8StaticActivationFloat8WeightConfig,
scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
granularity=PerTensor(),
),
}

factory = mode_map[mode]()
quantize_(model, factory)

@@ -275,7 +280,10 @@ def test_fp8_weight_dimension_warning(self):
"torchao.quantization.quant_api", level="INFO"
) as log_context:
quantize_(
model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
model,
Float8DynamicActivationFloat8WeightConfig(
granularity=PerTensor(), version=1
),
)
print(model)

@@ -320,7 +328,8 @@ def test_mm_float8dq_per_row(
)
test_linear = copy.deepcopy(ref_linear)
quantize_(
test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
test_linear,
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1),
)

quant_weight = test_linear.weight
@@ -472,7 +481,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
# Create and quantize a model
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
model,
Float8DynamicActivationFloat8WeightConfig(
granularity=granularity, version=1
),
)

weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -506,7 +518,10 @@ def test_float8_tensor_slicing_per_tensor(self):
# Create and quantize with per-tensor granularity
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
model,
Float8DynamicActivationFloat8WeightConfig(
granularity=PerTensor(), version=1
),
)

original_weight = model.weight
@@ -537,7 +552,8 @@ def test_float8_tensor_slicing_per_row(self):
# Create and quantize with per-row granularity
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
model,
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1),
)

original_weight = model.weight # Shape: (32, 64)
@@ -575,7 +591,10 @@ def test_float8_tensor_slicing_edge_cases(self):
# Create and quantize a model
model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
quantize_(
model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
model,
Float8DynamicActivationFloat8WeightConfig(
granularity=PerTensor(), version=1
),
)

original_weight = model.weight
@@ -613,7 +632,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
quant_model = copy.deepcopy(ref_model)
quantize_(
quant_model,
Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
Float8DynamicActivationFloat8WeightConfig(
granularity=granularity, version=1
),
)

# Create input with batch size that works well with slicing
@@ -720,6 +741,7 @@ def test_preprocess_scale_3d_reshape(self):
self.assertEqual(result.shape, expected_shape)

@torch.no_grad()
@unittest.skip("test is flaky in CI, will turn on a bit later")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
@@ -743,7 +765,13 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
m = torch.nn.Sequential(
torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
)
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
quantize_(
m,
Float8DynamicActivationFloat8WeightConfig(
granularity=granularity, version=1
),
)

m = torch.compile(m, mode=torch_compile_mode)
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

6 changes: 3 additions & 3 deletions test/float8/test_base.py
@@ -473,10 +473,10 @@ def test_quantize(self):
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
from torchao.quantization.quant_api import float8_weight_only, quantize_
from torchao.quantization import Float8WeightOnlyConfig, quantize_

quantize_(m, float8_weight_only())
assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
quantize_(m, Float8WeightOnlyConfig())
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
"Post quantization dtype should be torch.float8_e4m3fn"
)
with torch.no_grad():
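The assertion change above reflects the updated default for `Float8WeightOnlyConfig`: the float8 payload is now reached through `weight.qdata` rather than `weight.tensor_impl.float8_data`. A minimal sketch mirroring the updated assertion (requires CUDA, as in `test_quantize`):

```python
import torch
import torch.nn as nn

from torchao.quantization import Float8WeightOnlyConfig, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()
quantize_(m, Float8WeightOnlyConfig())  # default (bumped) config version

# With the new default version, the quantized float8 data lives on `qdata`.
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
    "Post quantization dtype should be torch.float8_e4m3fn"
)
```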
70 changes: 70 additions & 0 deletions test/integration/test_loading_deprecated_checkpoint.py
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import warnings

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.utils import is_sm_at_least_89

_MODEL_NAME_AND_VERSIONS = [
("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev", 1),
]


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Nedd sm89+")
class TestLoadingDeprecatedCheckpoint(TestCase):
@common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
def test_load_model_and_run(self, model_name_and_version):
"""Test that we print correct warning message when loading a deprecated checkpoint"""
# Load and quantize model
model_name, version = model_name_and_version
with warnings.catch_warnings(record=True) as caught_warnings:
quantized_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="bfloat16",
device_map="cuda",
)
assert any(
"Stored version is not the same as current default version of the config"
in str(w.message)
for w in caught_warnings
), "Didn't get expected warning message for version mismatch"

assert any(
"Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
in str(w.message)
for w in caught_warnings
), "Didn't get expected warning message for deprecation"
assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
assert (
quantized_model.config.quantization_config.quant_type.version == version
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = ("Hello, my name is",)
inputs = tokenizer(
prompt,
return_tensors="pt",
).to("cuda")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
# make sure it runs
_ = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)


common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)

if __name__ == "__main__":
run_tests()