
Commit 985c91d

apaszke authored and PenghuiCheng committed
Improve support for tracing sizes, add more tracer warnings (pytorch#11288)
Summary: Many constructors like `torch.zeros` or `torch.randn` didn't support size tracing correctly, which is fixed by this pass. The same issue has been fixed in the legacy tensor constructors. Additionally, new tensor constructors which do not participate in tracing (most notably `torch.tensor`, `torch.as_tensor` and `torch.from_numpy`) now raise a warning when they are used while tracing. Finally, entering a traceable operation disables tracing in its body. This is needed because zdevito

Pull Request resolved: pytorch#11288
Reviewed By: ezyang
Differential Revision: D9751183
Pulled By: apaszke
fbshipit-source-id: 51444a39d76a3e164adc396c432fd5ee3c8d5f7f
1 parent a79cf6c commit 985c91d

12 files changed: +61 −41 lines changed
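
A hedged, illustrative Python sketch of the behaviour described in the summary (not part of the commit; it assumes the `torch.jit.trace(fn, example_inputs)` API and hypothetical function names): sizes taken from another tensor's shape should be recorded symbolically when passed to constructors such as `torch.zeros`, while constructors that cannot participate in tracing, such as `torch.tensor`, are expected to emit a tracer warning.

import torch

# Hedged sketch: after this change, a size coming from another tensor's shape
# should flow into the trace instead of being baked in as a constant.
def make_zeros(x):
    return torch.zeros(x.size(0), x.size(1)) + x

traced = torch.jit.trace(make_zeros, (torch.rand(3, 4),))
print(traced.graph)  # aten::zeros should consume traced size values

# Constructors that do not participate in tracing (torch.tensor,
# torch.as_tensor, torch.from_numpy, torch.Tensor, ...) are expected to emit
# a tracer warning when called while a trace is being recorded.
def uses_constant(x):
    return x + torch.tensor([1.0, 2.0, 3.0])

traced_warn = torch.jit.trace(uses_constant, (torch.rand(3),))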

test/expect/TestScript.test_index_put_trace_with_view.expect

Lines changed: 2 additions & 2 deletions
@@ -7,6 +7,6 @@ graph(%0 : Double(100)
   %6 : int = prim::Constant[value=0]()
   %7 : Long(4) = aten::_cast_Long(%1, %6)
   %8 : Dynamic[] = prim::ListConstruct(%7)
-  %20 : Double(100) = aten::index_put(%0, %8, %5)
-  return (%20);
+  %9 : Double(100) = aten::index_put(%0, %8, %5)
+  return (%9);
 }

test/expect/TestScript.test_index_put_trace_without_view.expect

Lines changed: 2 additions & 2 deletions
@@ -4,6 +4,6 @@ graph(%0 : Double(100)
   %3 : int = prim::Constant[value=0]()
   %4 : Long(4) = aten::_cast_Long(%1, %3)
   %5 : Dynamic[] = prim::ListConstruct(%4)
-  %17 : Double(100) = aten::index_put(%0, %5, %2)
-  return (%17);
+  %6 : Double(100) = aten::index_put(%0, %5, %2)
+  return (%6);
 }

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 4 additions & 4 deletions
@@ -346,11 +346,11 @@ def test_rnn_init_predict_split(self):
         mp = onnx.ModelProto.FromString(do_export(model, input, export_params=self.embed_params)[0])
         prepared = c2.prepare(mp, device='CPU')
         if self.embed_params:
-            assert len(prepared.init_net.op) == 1038
-            assert len(prepared.predict_net.op) == 101
+            assert len(prepared.init_net.op) == 1019
+            assert len(prepared.predict_net.op) == 142
         else:
-            assert len(prepared.init_net.op) == 27
-            assert len(prepared.predict_net.op) == 1112
+            assert len(prepared.init_net.op) == 8
+            assert len(prepared.predict_net.op) == 1153

     def test_alexnet(self):
         state_dict = model_zoo.load_url(model_urls['alexnet'], progress=False)

test/test_jit.py

Lines changed: 4 additions & 2 deletions
@@ -916,6 +916,7 @@ def f(x, y):

         torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])

+    @suppress_warnings
     def test_constant(self):
         x = torch.randn(2, 2, requires_grad=True)

@@ -6197,15 +6198,15 @@ def test_index_put(target, indices, rhs):
             target[indices] = rhs
             return target

-        self.assertExpected(str(test_index_put.graph))
+        self.assertExpectedGraph(test_index_put.graph)

     def test_index_put_trace_without_view(self):
         @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
         def test_index_put(target, indices, rhs):
             target[indices] = rhs
             return target

-        self.assertExpected(str(test_index_put.graph))
+        self.assertExpectedGraph(test_index_put.graph)

     def test_annotated_script_fn(self):
         @torch.jit.script

@@ -6895,6 +6896,7 @@ def forward(self, x):
         net = Net(upscale_factor=4)
         self.checkTrace(net, (torch.rand(5, 1, 64, 64),))

+    @suppress_warnings
     def test_time_sequence_prediction(self):
         class Sequence(torch.jit.ScriptModule):
             def __init__(self):

tools/autograd/gen_variable_type.py

Lines changed: 9 additions & 22 deletions
@@ -128,13 +128,15 @@

 PRE_RECORD_TRACE = CodeTemplate("""\
 torch::jit::Node* node = nullptr;
+std::shared_ptr<jit::tracer::TracingState> tracer_state;
 if (jit::tracer::isTracing()) {
-  auto& graph = jit::tracer::getTracingState()->graph;
-  node = graph->create(jit::aten::${trace_name}, /*outputs=*/0);
+  tracer_state = jit::tracer::getTracingState();
+  node = tracer_state->graph->create(jit::aten::${trace_name}, /*outputs=*/0);
   jit::tracer::recordSourceLocation(node);
   ${add_trace_inputs}
-  graph->appendNode(node);
+  tracer_state->graph->appendNode(node);
   ${inplace_guard}
+  jit::tracer::setTracingState(nullptr);
 }
 """)

@@ -145,35 +147,20 @@
 ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${input}", ${input});""")

 POST_RECORD_TRACE = CodeTemplate("""\
-if (jit::tracer::isTracing()) {
+if (tracer_state) {
+  jit::tracer::setTracingState(std::move(tracer_state));
   ${record_trace_outputs}
 }
 """)

-RECORD_ATTRIBUTE = CodeTemplate("""\
-setattr(trace_info.n, jit::attr::${attr_name}, ${name});""")
-
-RECORD_POSITIONAL_ATTRIBUTE = CodeTemplate("""\
-setposattr(trace_info.n, ${i}, "${name}", ${name});""")
-
-POSITIONAL_ATTR_NYI = """\
-throw std::runtime_error("Can't have size-dependent arguments to functions that "
-                         "take variable number of tensor arguments");
-"""
-

 def should_trace(declaration):
-    # Operations involving Generator, Storage, Type are not traceable
-    # at the moment
-    if any(arg['simple_type'] in {'Generator', 'Storage', 'ScalarType', 'Type', 'optional<ScalarType>'}
-           for arg in declaration['arguments']):
+    # Operations involving Storage or Type are not traceable at the moment
+    if any(arg['simple_type'] in {'Storage', 'Type'} for arg in declaration['arguments']):
         return False
     # We can't trace functions which don't have any Tensor or TensorList returns
     if 'Tensor' not in declaration['return_type']:
         return False
-    tensor_args = [arg for arg in declaration['arguments'] if arg['simple_type'] in {'Tensor', 'TensorList'}]
-    if len(tensor_args) == 0:
-        return False
     name = declaration['name']
     base_name = name[:-1] if declaration['inplace'] else name[:-4] if name.endswith('_out') else name
     if base_name in DONT_RECORD_TRACE:
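
The template change above implements a stash-and-restore pattern: the tracing state is saved and cleared before the op's body runs, so operations invoked internally by a traced op are not recorded as separate nodes, and the state is restored afterwards to record the outputs. A rough Python-only model of that control flow (names here are hypothetical; the real logic is the generated C++ produced by the templates above):

# Illustrative-only model of the PRE_RECORD_TRACE / POST_RECORD_TRACE flow.
_tracing_state = {"graph": []}          # stand-in for the tracer's thread-local state

def is_tracing():
    return _tracing_state is not None

def traced_op(op_name, body, *inputs):
    global _tracing_state
    saved_state = None
    if is_tracing():
        saved_state = _tracing_state
        saved_state["graph"].append((op_name, inputs))   # record the node and its inputs
        _tracing_state = None                            # disable tracing inside the body
    try:
        outputs = body(*inputs)                          # ops called here are not recorded
    finally:
        if saved_state is not None:
            _tracing_state = saved_state                 # restore state to record outputs
    return outputs

out = traced_op("aten::zeros", lambda n: [0.0] * n, 4)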

tools/autograd/templates/python_torch_functions.cpp

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include "torch/csrc/utils/tensor_layouts.h"
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_numpy.h"
+#include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"

 #include <ATen/ATen.h>

@@ -320,13 +321,15 @@ static PyObject * THPVariable_randint(PyObject* self_, PyObject* args, PyObject*
 static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
+  jit::tracer::warn("torch.as_tensor");
   return THPVariable_Wrap(torch::utils::as_tensor(default_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }

 static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
 {
   HANDLE_TH_ERRORS
+  jit::tracer::warn("torch.from_numpy");
   auto data = torch::utils::tensor_from_numpy(arg);
   return THPVariable_Wrap(make_variable(std::move(data), /*requires_grad=*/false));
   END_HANDLE_TH_ERRORS

@@ -351,13 +354,15 @@ static PyObject * THPVariable__promote_types(PyObject* self, PyObject* args, PyO
 static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
+  jit::tracer::warn("torch.sparse_coo_tensor");
   return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(default_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }

 static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
+  jit::tracer::warn("torch.tensor");
   return THPVariable_Wrap(torch::utils::tensor_ctor(default_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }

torch/csrc/autograd/python_variable.cpp

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 #include "torch/csrc/utils/python_strings.h"
 #include "torch/csrc/utils/python_arg_parser.h"
 #include "torch/csrc/utils/tensor_new.h"
+#include "torch/csrc/jit/tracer.h"

 #include <ATen/ATen.h>

@@ -125,6 +126,7 @@ static void THPVariable_dealloc(THPVariable* self)
 static PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
 {
   HANDLE_TH_ERRORS
+  jit::tracer::warn("torch.Tensor");
   auto& default_type = torch::tensors::get_default_tensor_type();
   auto tensor = torch::utils::legacy_tensor_ctor(default_type, args, kwargs);
   return THPVariable_NewWithVar(type, std::move(tensor));

torch/csrc/jit/tracer.cpp

Lines changed: 7 additions & 4 deletions
@@ -26,8 +26,9 @@ void genericAddInput(Node *n, T value) {
   n->addInput(v);
 }

-void badArgType() {
-  AT_ERROR("Found an unsupported argument type in the JIT tracer. File a bug report.");
+template<typename T>
+void badArgType(const T& v) {
+  AT_ERROR("Found an unsupported argument type in the JIT tracer: ", at::demangle_type<T>(), ". File a bug report.");
 }

 thread_local std::shared_ptr<TracingState> tracing_state;

@@ -39,8 +40,10 @@ void addInputs(Node *n, const char * name, bool value) { detail::g
 void addInputs(Node *n, const char * name, double value) { detail::genericAddInput(n, value); }
 void addInputs(Node *n, const char * name, const at::Scalar& value) { detail::genericAddInput(n, value); }
 void addInputs(Node *n, const char * name, const at::Tensor& value) { n->addInput(getValueTrace(value)); }
-void addInputs(Node *n, const char * name, const std::string& value) { detail::badArgType(); }
-void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { detail::badArgType(); }
+void addInputs(Node *n, const char * name, const std::string& value) { detail::badArgType(value); }
+void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { detail::badArgType(value); }
+void addInputs(Node *n, const char * name, at::Generator * value) { detail::badArgType(value); }
+void addInputs(Node *n, const char * name, at::ScalarType value) { detail::badArgType(value); }

 void addInputs(Node *n, const char * name, at::TensorList value) {
   Graph *g = n->owningGraph();

torch/csrc/jit/tracer.h

Lines changed: 2 additions & 0 deletions
@@ -172,6 +172,8 @@ TORCH_API void addInputs(Node *n, const char * name, const ArrayRef<double>& val
 TORCH_API void addInputs(Node *n, const char * name, const std::string& value);
 TORCH_API void addInputs(Node *n, const char * name, const at::SparseTensorRef& value);
 TORCH_API void addInputs(Node *n, const char * name, const at::TensorOptions& value);
+TORCH_API void addInputs(Node *n, const char * name, at::Generator * value);
+TORCH_API void addInputs(Node *n, const char * name, at::ScalarType value);

 template<size_t N>
 void addInputs(Node *n, const char * name, std::array<bool, N> value) {

torch/csrc/utils/python_arg_parser.h

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(int i, std::vector<in
   try {
     // Elements of torch.Size are tensors during tracing, and we need to record extra
     // information before they are turned into an IntList
-    if (traceable && THPVariable_Check(obj)) {
+    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
       auto & var = THPVariable_Unpack(obj);
       jit::tracer::ArgumentStash::stashIntListElem(
           signature.params[i].name, size, idx, var);

torch/csrc/utils/tensor_new.cpp

Lines changed: 4 additions & 4 deletions
@@ -511,7 +511,7 @@ Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
 Tensor new_empty(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_empty(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
-  });
+  }, /*traceable=*/true);

   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);

@@ -525,7 +525,7 @@ Tensor new_empty(const Type& type, PyObject* args, PyObject* kwargs) {
 Tensor new_full(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_full(IntList size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
-  });
+  }, /*traceable=*/true);

   ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);

@@ -539,7 +539,7 @@ Tensor new_full(const Type& type, PyObject* args, PyObject* kwargs) {
 Tensor new_ones(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_ones(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
-  });
+  }, /*traceable=*/true);

   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);

@@ -553,7 +553,7 @@ Tensor new_ones(const Type& type, PyObject* args, PyObject* kwargs) {
 Tensor new_zeros(const Type& type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_zeros(IntList size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
-  });
+  }, /*traceable=*/true);

   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
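
A hedged usage sketch (not part of the diff) of what marking these parsers traceable enables: sizes passed to the legacy `new_*` constructors can now be stashed as traced IntList elements rather than baked-in constants. The function name below is hypothetical and assumes the `torch.jit.trace(fn, example_inputs)` API.

import torch

def expand_like(x):
    # new_zeros goes through the now-traceable PythonArgParser above, so the
    # sizes taken from x.size() can be recorded in the trace.
    return x.new_zeros(x.size(0), x.size(1)) + 1

traced = torch.jit.trace(expand_like, (torch.rand(2, 5),))
print(traced.graph)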

torch/onnx/symbolic.py

Lines changed: 19 additions & 0 deletions
@@ -936,6 +936,25 @@ def zeros_like(g, input):
     return g.op("Sub", input, input).setType(input.type().contiguous())


+scalar_type_to_onnx = [
+    cast_pytorch_to_onnx["Byte"],
+    cast_pytorch_to_onnx["Char"],
+    cast_pytorch_to_onnx["Short"],
+    cast_pytorch_to_onnx["Int"],
+    cast_pytorch_to_onnx["Long"],
+    cast_pytorch_to_onnx["Half"],
+    cast_pytorch_to_onnx["Float"],
+    cast_pytorch_to_onnx["Double"],
+]
+
+
+@parse_args('v', 'i', 'i', 'v')
+def zeros(g, shape, scalar_type, layout, device):
+    # NOTE: no way to set device in ONNX, so we ignore it
+    return g.op("ConstantFill", shape, dtype_i=scalar_type_to_onnx[scalar_type],
+                input_as_shape_i=1, value_f=0)
+
+
 def full_like(g, input, fill_value):
     # TODO: a more efficient implementation (ConstantFill?)
     return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1)))
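
For context, a hedged example of the kind of module the new `zeros` symbolic is meant to handle during ONNX export; the module name is hypothetical and the exact exported operators may differ across exporter/ONNX versions.

import io
import torch

class ZeroPad(torch.nn.Module):
    def forward(self, x):
        # torch.zeros with a shape derived from the input; the symbolic above
        # maps aten::zeros to an ONNX ConstantFill whose shape is an input.
        return torch.cat([x, torch.zeros(x.size(0), 2)], dim=1)

buf = io.BytesIO()
torch.onnx.export(ZeroPad(), (torch.rand(3, 4),), buf)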
