
Commit 7ff4af0

shubhambhokare1 authored and BowenBao committed

[ONNX] Add module name as PythonOp attribute (pytorch#67193)

* Add module name as pythonOp attr
* Move to trace_post_record
* Add tests
* Code compactness
1 parent 792ec61 commit 7ff4af0

File tree

4 files changed: +55 / -0 lines changed
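For context on what this change enables (a minimal sketch, not part of the commit: MyRelu and M are illustrative names, and how torch.jit.trace records autograd Functions can vary across PyTorch versions), the prim::PythonOp node recorded for a custom autograd Function now also carries the Function's __module__ as a string attribute named "module", readable through the TorchScript graph API:

import torch

class MyRelu(torch.autograd.Function):  # illustrative stand-in for any custom Function
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad = grad_output.clone()
        grad[x < 0] = 0
        return grad

class M(torch.nn.Module):
    def forward(self, x):
        return MyRelu.apply(x)

# Tracing records the Function call as a prim::PythonOp node.
traced = torch.jit.trace(M(), torch.randn(2, 3))
for node in traced.graph.nodes():
    if node.kind() == "prim::PythonOp":
        print(node.s("module"))  # new in this commit, e.g. "__main__"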

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -297,6 +297,7 @@ namespace c10 {
   _(attr, transA) \
   _(attr, transB) \
   _(attr, name) \
+  _(attr, module) \
   _(attr, beg) \
   _(attr, idx) \
   _(attr, split) \

test/onnx/autograd_helper.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Owner(s): ["module: onnx"]
+
+import torch
+
+# Autograd function that is a replica of the autograd function in
+# test_utility_funs.py (test_autograd_module_name)
+class CustomFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        return input.clamp(min=0)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        grad_input = grad_output.clone()
+        grad_input[input < 0] = 0
+        return grad_input
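A quick sanity check of the helper (a sketch, not part of the diff; it assumes test/onnx is on sys.path so that autograd_helper is importable): forward clamps like ReLU, and backward masks the gradient wherever the input was negative.

import torch
from autograd_helper import CustomFunction

x = torch.tensor([-1.0, 2.0], requires_grad=True)
y = CustomFunction.apply(x)  # forward: clamp(min=0) -> tensor([0., 2.])
y.sum().backward()           # backward: gradient zeroed where x < 0
print(x.grad)                # tensor([0., 1.])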

test/onnx/test_utility_funs.py

Lines changed: 31 additions & 0 deletions
@@ -13,6 +13,7 @@
     _set_operator_export_type,
     _set_onnx_shape_inference)
 import torch.utils.cpp_extension
+from autograd_helper import CustomFunction as CustomFunction2
 from test_pytorch_common import (skipIfUnsupportedMinOpsetVersion,
                                  skipIfUnsupportedMaxOpsetVersion)
 import caffe2.python.onnx.backend as backend

@@ -977,6 +978,36 @@ def forward(self, input):
         iter = graph.nodes()
         self.assertEqual(next(iter).kind(), "prim::PythonOp")

+    def test_autograd_module_name(self):
+        class CustomFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, input):
+                ctx.save_for_backward(input)
+                return input.clamp(min=0)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                input, = ctx.saved_tensors
+                grad_input = grad_output.clone()
+                grad_input[input < 0] = 0
+                return grad_input
+
+        class Custom(torch.nn.Module):
+            def forward(self, input):
+                return CustomFunction.apply(input) + CustomFunction2.apply(input)
+
+        model = Custom()
+        batch = torch.FloatTensor(1, 3)
+
+        graph, _, _ = self._model_to_graph(model, batch,
+                                           input_names=["batch"], dynamic_axes={"batch": [0, 1]})
+        iter = graph.nodes()
+        autograd1 = next(iter)
+        autograd2 = next(iter)
+        self.assertEqual(autograd1.kind(), "prim::PythonOp")
+        self.assertEqual(autograd2.kind(), "prim::PythonOp")
+        self.assertNotEqual(autograd1.s("module"), autograd2.s("module"))
+
     def test_unused_initializers(self):
         class Model(torch.nn.Module):
             def __init__(self):
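The final assertNotEqual holds because the two Functions are defined in different Python modules: CustomFunction inside test_utility_funs.py itself and CustomFunction2 in autograd_helper.py, so their __module__ strings differ. A minimal illustration of the values being compared (a sketch; the first value depends on how the defining file is executed, and it assumes autograd_helper is importable):

import torch
from autograd_helper import CustomFunction as CustomFunction2

class CustomFunction(torch.autograd.Function):
    pass  # body omitted; only the class's __module__ matters here

print(CustomFunction.__module__)   # e.g. "__main__" or "test_utility_funs"
print(CustomFunction2.__module__)  # "autograd_helper"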

torch/csrc/autograd/python_function.cpp

Lines changed: 5 additions & 0 deletions
@@ -563,6 +563,11 @@ static void _trace_post_record(
   }

   node->i_(jit::attr::inplace, is_inplace);
+  if (auto module_name = PyDict_GetItemString(((PyTypeObject*)op_obj)->tp_dict, "__module__")) {
+    if (auto ptr = PyUnicode_AsUTF8(module_name)) {
+      node->s_(jit::attr::module, std::string(ptr));
+    }
+  }

   // Isolate C variable ptrs in a vector
   int num_outputs = PyTuple_GET_SIZE(output_objects);
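A rough Python analogue of the new C++ lines (illustrative only; fn and node are hypothetical stand-ins for the applied autograd Function instance and its recorded graph node): _trace_post_record looks up "__module__" in the Function class's own dict (its tp_dict), and if it converts cleanly to a UTF-8 string, stores it on the node as the "module" string attribute.

def record_module_attr(fn, node):
    # Python sketch of the C++ above, under the stated assumptions.
    module_name = type(fn).__dict__.get("__module__")  # PyDict_GetItemString on tp_dict
    if isinstance(module_name, str):                   # PyUnicode_AsUTF8 succeeded
        node.s_("module", module_name)                 # node->s_(jit::attr::module, ...)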
