
Commit b456924

apaszke authored and Rob Kunkle committed
Slightly relax the constraints on argument and return types to script functions (pytorch#9969)
Summary: This lays out initial support for taking and returning a richer set of types than only tensors. Floats and ints are already valid, lists are straightforward to add, and tuples need some discussion. Based on top of pytorch#9948. Review only the last commit. zdevito

Pull Request resolved: pytorch#9969
Reviewed By: zdevito
Differential Revision: D9076973
Pulled By: apaszke

fbshipit-source-id: 5a1fe912ea6b79ab2bfd0dcce265eb05855b5ff0
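For orientation, here is a minimal sketch of what the relaxed typing enables, modeled directly on the new test_script_non_tensor_args_outputs test added below (the name fn is just illustrative):

import torch

@torch.jit.script
def fn(x, y):
    # type: (Tensor, float) -> float
    return float((x + y).sum())

x = torch.ones(2, 2)
z = fn(x, 1.0)  # z is a Python float (8.0), not a zero-dim tensor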
1 parent 8b9767a commit b456924

File tree: 8 files changed, +97 −35 lines


test/expect/TestJit.test_constant_prop_nested.expect

Lines changed: 1 addition & 2 deletions
@@ -11,6 +11,5 @@ graph(%a : Dynamic) {
   %6 : int = prim::Constant[value=1]()
   -> (%6)
   }
-  %7 : Long() = prim::NumToTensor(%c)
-  return (%7);
+  return (%c);
 }

test/test_jit.py

Lines changed: 14 additions & 3 deletions
@@ -3717,10 +3717,10 @@ def test_unknown_builtin(self):
             def unknown_builtin(x):
                 return x.splork(3)
 
-    def test_expected_tensor_found_tuple(self):
-        with self.assertRaisesRegex(RuntimeError, 'expected a tensor value but found'):
+    def test_return_tuple(self):
+        with self.assertRaisesRegex(RuntimeError, 'only supported return types'):
             @torch.jit.script
-            def return_tuple_wrong(x):
+            def return_tuple(x):
                 a = (x, x)
                 return a, x
 
@@ -4439,6 +4439,17 @@ def tuple_arg(x):
                 # type: (Tuple[Tensor, Tensor]) -> Tensor
                 return x + 1
 
+    def test_script_non_tensor_args_outputs(self):
+        @torch.jit.script
+        def fn(x, y):
+            # type: (Tensor, float) -> float
+            return float((x + y).sum())
+
+        x = torch.ones(2, 2)
+        z = fn(x, 1)
+        self.assertIsInstance(z, float)
+        self.assertEqual(z, 8.)
+
     @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
     def test_inline_and_run_annotated_script_fn(self):
         @torch.jit.script

torch/csrc/jit/graph_executor.cpp

Lines changed: 1 addition & 8 deletions
@@ -241,14 +241,7 @@ struct GraphExecutorImpl {
     , symbolically_differentiable(symbolically_differentiable)
     , may_introduce_gradient(calcMayIntroduceGradient(this->graph->block())) {}
   GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
-    : GraphExecutorImpl(graph, optimize, isDifferentiable(*graph)) {
-      for(auto input : graph->inputs()) {
-        JIT_ASSERTM(input->type()->kind() != TypeKind::TupleType, "tuples cannot be inputs to the graph");
-      }
-      for(auto output : graph->outputs()) {
-        JIT_ASSERTM(output->type()->kind() != TypeKind::TupleType, "tuples cannot be outputs to the graph");
-      }
-    }
+    : GraphExecutorImpl(graph, optimize, isDifferentiable(*graph)) {}
 
   // entry point where execution begins
   void run(Stack & stack) {

torch/csrc/jit/init.cpp

Lines changed: 5 additions & 4 deletions
@@ -71,7 +71,7 @@ void initJITBindings(PyObject *module) {
   })
   .def("_jit_pass_lint", LintGraph)
   .def("_jit_pass_shape_analysis", [](Graph& graph, py::tuple inputs, bool with_grad) {
-    PropagateInputShapes(graph, ArgumentSpec(with_grad, createStack(inputs)));
+    PropagateInputShapes(graph, ArgumentSpec(with_grad, createStack(inputs, graph.inputs())));
   })
   .def("_jit_pass_remove_expands", RemoveExpands)
   .def("_jit_pass_erase_number_types", EraseNumberTypes)
@@ -186,15 +186,16 @@ void initJITBindings(PyObject *module) {
     return ge.graph();
   })
   .def("graph_for", [](GraphExecutor& ge, py::args args) {
-    return ge.graphFor(createStack(args));
+    return ge.graphFor(createStack(args, ge.graph()->inputs()));
   })
   .def("get_debug_state", [](GraphExecutor& ge) {
     return ge.getDebugState();
   })
   .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
-    auto stack = createStack(args);
+    const auto & graph = ge.graph();
+    auto stack = createStack(args, graph->inputs());
     ge.run(stack);
-    return wrapStack(std::move(stack));
+    return wrapStack(std::move(stack), graph->outputs());
   });

torch/csrc/jit/pybind_utils.h

Lines changed: 52 additions & 8 deletions
@@ -4,26 +4,70 @@
 
 namespace torch { namespace jit {
 
-inline Stack createStack(const py::tuple& tuple, size_t reserve_extra_space = 0) {
+inline Stack createStack(const py::tuple& tuple, at::ArrayRef<Value*> inputs, size_t reserve_extra_space = 0) {
+  if (tuple.size() != inputs.size()) {
+    throw std::runtime_error("expected " + std::to_string(inputs.size()) +
+                             " inputs, but got " + std::to_string(tuple.size()));
+  }
+  static const auto castToIValue = [](const py::object& obj, Type& t) -> IValue {
+    switch (t.kind()) {
+      case TypeKind::DynamicType:
+      case TypeKind::TensorType:
+        return py::cast<autograd::Variable>(obj);
+      case TypeKind::FloatType:
+        return py::cast<double>(obj);
+      case TypeKind::IntType:
+        return py::cast<int64_t>(obj);
+      case TypeKind::NoneType:
+        return {};
+      case TypeKind::ListType:
+      case TypeKind::TupleType:
+        throw std::runtime_error("Lists and tuples are not supported yet");
+      case TypeKind::NumberType:
+        throw std::runtime_error("Insufficient type information to convert input");
+    }
+    throw std::runtime_error("Missing cases in castToIValue! File a bug report.");
+  };
   Stack result;
   result.reserve(tuple.size() + reserve_extra_space);
-  for(auto e : tuple) {
-    result.push_back(py::cast<autograd::Variable>(e));
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    result.push_back(castToIValue(tuple[i], *inputs[i]->type()));
   }
   return result;
 }
 
-inline py::object wrapStack(Stack&& outputs) {
+inline py::object wrapStack(Stack&& outputs, at::ArrayRef<Value*> output_vals) {
+  if (outputs.size() != output_vals.size()) {
+    throw std::runtime_error("expected " + std::to_string(output_vals.size()) +
+                             " outputs, but got " + std::to_string(outputs.size()));
+  }
+  static const auto createOutput = [](IValue && ivalue, Value * value) -> py::object {
+    switch (value->type()->kind()) {
+      case TypeKind::DynamicType:
+      case TypeKind::TensorType:
+        return py::cast(autograd::Variable(ivalue.toTensor()));
+      case TypeKind::FloatType:
+        return py::cast(ivalue.toDouble());
+      case TypeKind::IntType:
+        return py::cast(ivalue.toInt());
+      case TypeKind::NoneType:
+        return py::none();
+      case TypeKind::ListType:
+      case TypeKind::TupleType:
+        throw std::runtime_error("Lists and tuples are not supported yet");
+      case TypeKind::NumberType:
+        throw std::runtime_error("Insufficient type information to convert input");
+    }
+    throw std::runtime_error("Missing cases in createOutput! File a bug report.");
+  };
   if (outputs.size() == 0) {
     return py::none();
   } else if (outputs.size() == 1) {
-    JIT_ASSERT(outputs[0].isTensor());
-    return py::cast(autograd::as_variable_ref(std::move(outputs[0]).toTensor()));
+    return createOutput(std::move(outputs[0]), output_vals[0]);
  } else {
    py::tuple tuple(outputs.size());
    for(size_t i = 0; i < outputs.size(); i++) {
-      JIT_ASSERT(outputs[i].isTensor());
-      tuple[i] = py::cast(autograd::as_variable_ref(std::move(outputs[i]).toTensor()));
+      tuple[i] = createOutput(std::move(outputs[i]), output_vals[i]);
    }
    return tuple;
  }
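A side effect of the new size check in createStack is that arity mismatches now surface as a clear Python-side error instead of an assertion. A hedged sketch, assuming the script function's call path routes through the updated binding:

import torch

@torch.jit.script
def add(x, y):
    return x + y

add(torch.ones(2))  # RuntimeError: expected 2 inputs, but got 1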

torch/csrc/jit/script/compiler.cpp

Lines changed: 9 additions & 5 deletions
@@ -685,12 +685,16 @@ struct to_ir {
     auto range = return_stmt.range();
     size_t return_type_idx = 0;
     for (auto& r : results) {
-      if(r->type()->isSubtypeOf(NumberType::get())) {
-        graph->registerOutput(numToTensor(range, r));
-      } else {
-        ensureTensor(range, r);
-        graph->registerOutput(r);
+      // TODO: support tuples and lists as returns
+      auto return_kind = r->type()->kind();
+      if (return_kind != TypeKind::TensorType &&
+          return_kind != TypeKind::DynamicType &&
+          return_kind != TypeKind::IntType &&
+          return_kind != TypeKind::FloatType) {
+        throw ErrorReport(return_stmt.range()) << "The only supported return types "
+                                               << "are tensors, ints and floats";
       }
+      graph->registerOutput(r);
       TypePtr type = DynamicType::get();
       if (typed_def.schema) {
         type = typed_def.schema->returns.at(return_type_idx).type;
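The compiler change above replaces the old coerce-to-tensor path with an explicit whitelist of return kinds, so an unsupported return fails at script-compilation time, which is exactly what the renamed test_return_tuple exercises. A minimal sketch:

import torch

@torch.jit.script  # RuntimeError: The only supported return types are tensors, ints and floats
def return_tuple(x):
    a = (x, x)
    return a, x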

torch/csrc/jit/script/init.cpp

Lines changed: 10 additions & 4 deletions
@@ -370,10 +370,15 @@ static void gatherParametersAndBuffers(std::vector<at::Tensor*> & values, const
   }
 }
 
+Stack createStack(const py::tuple& tuple, const Method& method) {
+  auto relevant_inputs = method.graph()->inputs().slice(0, method.num_inputs());
+  return createStack(tuple, relevant_inputs);
+}
+
 py::object runMethodFromPython(Method& m, py::args args) {
-  auto stack = createStack(args);
+  auto stack = createStack(args, m);
   m.run(stack);
-  return wrapStack(std::move(stack));
+  return wrapStack(std::move(stack), m.graph()->outputs());
 }
 
 void initJitScriptBindings(PyObject* module) {
@@ -502,7 +507,8 @@ void initJitScriptBindings(PyObject* module) {
   })
   .def("graph_for", [](Module& self, py::args args) {
     if (self.find_method("forward")) {
-      return self.get_method("forward").graph_for(createStack(args));
+      Method & m = self.get_method("forward");
+      return m.graph_for(createStack(args, m.graph()->inputs()));
     }
     throw std::runtime_error("Attempted to call graph_for on a Module without a compiled forward()");
   })
@@ -530,7 +536,7 @@ void initJitScriptBindings(PyObject* module) {
   .def("propagate_and_assign_input_and_output_shapes", &Method::propagate_and_assign_input_and_output_shapes)
   .def("params", &Method::params)
   .def("graph_for", [](Method& self, py::args args) {
-    return self.graph_for(createStack(args));
+    return self.graph_for(createStack(args, self.graph()->inputs()));
   })
   .def("set_arg_and_return_types", [](Method &self, TypedDef &typed_def, bool method) {
     std::vector<Argument> arg_type_args, return_type_args;

torch/jit/annotations.py

Lines changed: 5 additions & 1 deletion
@@ -3,7 +3,7 @@
 import ast
 import inspect
 import torch
-from torch._C import DynamicType, TupleType
+from torch._C import DynamicType, TupleType, FloatType, IntType
 from textwrap import dedent
 
 
@@ -209,4 +209,8 @@ def ann_to_type(ann):
         return DynamicType.get()
     elif is_tuple(ann):
         return TupleType([ann_to_type(a) for a in ann.__args__])
+    elif ann is float:
+        return FloatType.get()
+    elif ann is int:
+        return IntType.get()
     raise ValueError("The only supported annotations kinds are Tensor and Tuple[...]")
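With ann_to_type extended, plain float and int annotations now resolve to JIT types instead of falling through to the ValueError. A quick sketch against this internal helper (an internal API, so subject to change):

from torch.jit.annotations import ann_to_type

ann_to_type(float)  # -> FloatType
ann_to_type(int)    # -> IntType
ann_to_type(str)    # still raises ValueError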
