Skip to content

Commit 8d89521

Browse files
committed
addressing comments
1 parent 62baf24 commit 8d89521

File tree

3 files changed

+18
-20
lines changed

3 files changed

+18
-20
lines changed

test/prototype/safetensors/test_safetensors.py renamed to test/prototype/safetensors/test_safetensors_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99

1010
from torchao import quantize_
11-
from torchao.prototype.quantization.safetensors_support import (
11+
from torchao.prototype.safetensors.safetensors_support import (
1212
load_tensor_state_dict,
1313
save_tensor_state_dict,
1414
)

torchao/prototype/quantization/safetensors_serialization.py renamed to torchao/prototype/safetensors/safetensors_serialization.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
from typing import Any, Dict
66

77
import torch
8+
import torchao
89

910
from torchao.quantization import Float8Tensor
10-
11-
ALLOWED_AO_MODULES = {
12-
"torchao.quantization",
13-
"torchao.dtypes",
14-
"torchao.quantization.quantize_.common",
15-
"torchao.quantization.quantize_.workflows",
11+
from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
12+
from torchao.quantization.quantize_.common import KernelPreference
13+
14+
ALLOWED_CLASSES = {
15+
"Float8Tensor": torchao.quantization.Float8Tensor,
16+
"Float8MMConfig": torchao.float8.inference.Float8MMConfig,
17+
"QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
18+
"PerRow": torchao.quantization.granularity.PerRow,
19+
"KernelPreference": KernelPreference,
1620
}
1721

1822

@@ -58,7 +62,7 @@ def default(self, o):
5862
return {"_type": "torch.dtype", "_data": str(o).split(".")[-1]}
5963

6064
if isinstance(o, enum.Enum):
61-
# Store the full path for enums to ensure uniqueness
65+
# Store the full class name for enums to ensure uniqueness
6266
return {"_type": f"{o.__class__.__name__}", "_data": o.name}
6367

6468
if isinstance(o, list):
@@ -81,6 +85,8 @@ def encode_value(self, value):
8185
except TypeError:
8286
pass
8387

88+
# Default case - return as is
89+
# (This will be processed by standard JSON encoder later)
8490
return value
8591

8692

@@ -97,19 +103,11 @@ def object_from_dict(data: Dict[str, Any]):
97103
if type_path == "torch.dtype":
98104
return getattr(torch, obj_data)
99105

100-
# Try to find the class in any of the allowed modules
101-
cls = None
102-
for module_path in ALLOWED_AO_MODULES:
103-
try:
104-
module = importlib.import_module(module_path)
105-
cls = getattr(module, type_path)
106-
break # Found the class, exit the loop
107-
except (ImportError, AttributeError):
108-
continue # Try the next module
106+
cls = ALLOWED_CLASSES.get(type_path)
109107

110108
# If we couldn't find the class in any allowed module, raise an error
111109
if cls is None:
112-
allowed_modules_str = ", ".join(ALLOWED_AO_MODULES)
110+
allowed_modules_str = ", ".join(ALLOWED_CLASSES)
113111
raise ValueError(
114112
f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
115113
)

torchao/prototype/quantization/safetensors_support.py renamed to torchao/prototype/safetensors/safetensors_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from safetensors.torch import load_file, save_file
77

8-
from torchao.prototype.quantization.safetensors_serialization import (
8+
from torchao.prototype.safetensors.safetensors_serialization import (
99
Float8TensorAttributeJSONEncoder,
1010
object_from_dict,
1111
)
@@ -118,7 +118,7 @@ def save_tensor_state_dict(
118118
tensors_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
119119

120120
metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
121-
elif isinstance(tensor, torch.Tensor):
121+
elif type(tensor) is torch.Tensor:
122122
tensors_dict = {"_data": tensor}
123123
metadata = json.dumps({"_type": torch.Tensor.__name__})
124124
else:

0 commit comments

Comments
 (0)