Skip to content

Commit 3516654

Browse files
tarun292 authored and pull[bot] committed
Save quantization_tag in export graph serialization (pytorch#127473)
Summary: `quantization_tag` is first-class metadata in quantization flows, and those flows preserve it. Because we want to store quantized exported graphs, serialization must preserve this metadata as well, since later flows rely on it. Only JSON-serializable metadata values can be serialized. Test Plan: Added a test case. Differential Revision: D57939282 Pull Request resolved: pytorch#127473 Approved by: https://github.com/angelayi
1 parent e11ddd3 commit 3516654

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

test/export/test_serialize.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,6 +1190,27 @@ def forward(self, x):
11901190
ep = deserialize(serialized_vals)
11911191
self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor))
11921192

1193+
def test_quantization_tag_metadata(self):
    """Check that ``quantization_tag`` node metadata survives serialization.

    Exports a module whose forward is a single ``x + x``, tags every
    ``aten.add.Tensor`` call node with ``quantization_tag = "foo"``,
    round-trips the exported program through ``serialize`` /
    ``deserialize``, and verifies the tag is still present on the add
    nodes of the deserialized graph.
    """

    class Foo(torch.nn.Module):
        def forward(self, x):
            return x + x

    def add_nodes(graph):
        # The nodes this test tags before, and inspects after, the round trip.
        return [
            node
            for node in graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.add.Tensor
        ]

    f = Foo()

    inputs = (torch.zeros(4, 4),)
    ep = export(f, inputs)

    tagged = add_nodes(ep.graph)
    # Without this guard the test would pass vacuously if the exported
    # graph contained no matching add node to tag.
    self.assertGreater(len(tagged), 0)
    for node in tagged:
        node.meta["quantization_tag"] = "foo"

    serialized_vals = serialize(ep)
    ep = deserialize(serialized_vals)

    restored = add_nodes(ep.graph)
    # Same guard on the deserialized side: make sure something is checked.
    self.assertGreater(len(restored), 0)
    for node in restored:
        # assertEqual reports the actual value on failure, and .get() turns
        # a dropped tag into an assertion failure instead of a KeyError.
        self.assertEqual(node.meta.get("quantization_tag"), "foo")
1213+
11931214

11941215
if __name__ == "__main__":
11951216
run_tests()

torch/_export/serde/serialize.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,9 @@ def export_nn_module_stack(val):
594594
if torch_fn := node.meta.get("torch_fn"):
595595
ret["torch_fn"] = ST_DELIMITER.join(list(torch_fn))
596596

597+
if quantization_tag := node.meta.get("quantization_tag"):
598+
ret["quantization_tag"] = json.dumps(quantization_tag)
599+
597600
return ret
598601

599602
def serialize_script_obj_meta(
@@ -2149,6 +2152,10 @@ def metadata_split(metadata):
21492152

21502153
if torch_fn_str := metadata.get("torch_fn"):
21512154
ret["torch_fn"] = tuple(torch_fn_str.split(ST_DELIMITER))
2155+
2156+
if quantization_tag_str := metadata.get("quantization_tag"):
2157+
ret["quantization_tag"] = json.loads(quantization_tag_str)
2158+
21522159
return ret
21532160

21542161
def deserialize_argument_spec(self, x: Argument) -> ep.ArgumentSpec:

0 commit comments

Comments
 (0)