5
5
from typing import Any , Dict
6
6
7
7
import torch
8
+ import torchao
8
9
9
10
from torchao .quantization import Float8Tensor
10
-
11
- ALLOWED_AO_MODULES = {
12
- "torchao.quantization" ,
13
- "torchao.dtypes" ,
14
- "torchao.quantization.quantize_.common" ,
15
- "torchao.quantization.quantize_.workflows" ,
11
+ from torchao .quantization .quantize_ .workflows import QuantizeTensorToFloat8Kwargs
12
+ from torchao .quantization .quantize_ .common import KernelPreference
13
+
14
+ ALLOWED_CLASSES = {
15
+ "Float8Tensor" : torchao .quantization .Float8Tensor ,
16
+ "Float8MMConfig" : torchao .float8 .inference .Float8MMConfig ,
17
+ "QuantizeTensorToFloat8Kwargs" : QuantizeTensorToFloat8Kwargs ,
18
+ "PerRow" : torchao .quantization .granularity .PerRow ,
19
+ "KernelPreference" : KernelPreference ,
16
20
}
17
21
18
22
@@ -58,7 +62,7 @@ def default(self, o):
58
62
return {"_type" : "torch.dtype" , "_data" : str (o ).split ("." )[- 1 ]}
59
63
60
64
if isinstance (o , enum .Enum ):
61
- # Store the full path for enums to ensure uniqueness
65
+ # Store the full class name for enums to ensure uniqueness
62
66
return {"_type" : f"{ o .__class__ .__name__ } " , "_data" : o .name }
63
67
64
68
if isinstance (o , list ):
@@ -81,6 +85,8 @@ def encode_value(self, value):
81
85
except TypeError :
82
86
pass
83
87
88
+ # Default case - return as is
89
+ # (This will be processed by standard JSON encoder later)
84
90
return value
85
91
86
92
@@ -97,19 +103,11 @@ def object_from_dict(data: Dict[str, Any]):
97
103
if type_path == "torch.dtype" :
98
104
return getattr (torch , obj_data )
99
105
100
- # Try to find the class in any of the allowed modules
101
- cls = None
102
- for module_path in ALLOWED_AO_MODULES :
103
- try :
104
- module = importlib .import_module (module_path )
105
- cls = getattr (module , type_path )
106
- break # Found the class, exit the loop
107
- except (ImportError , AttributeError ):
108
- continue # Try the next module
106
+ cls = ALLOWED_CLASSES .get (type_path )
109
107
110
108
# If we couldn't find the class in any allowed module, raise an error
111
109
if cls is None :
112
- allowed_modules_str = ", " .join (ALLOWED_AO_MODULES )
110
+ allowed_modules_str = ", " .join (ALLOWED_CLASSES )
113
111
raise ValueError (
114
112
f"Failed to find class { type_path } in any of the allowed modules: { allowed_modules_str } "
115
113
)
0 commit comments