
Commit 84b7daa

zdevito authored and soumith committed
Relax verify of VariableFlags (#4191)
* Fix another leak in pybind11 code. This time it is caused by an upstream pybind11 bug: pybind/pybind11#1216. This change makes the code take a non-buggy pathway.

* Relax verify of VariableFlags

If we trace with a defined tensor but later see a run with an undefined tensor, we now allow that run to happen, replacing the tensor with zeros. This also fixes a bug where stage 0 tensors were not checked against their verify flags.

This change does _not_ handle all bad situations that can happen. For instance, if the first thing traced has an undefined tensor but a later tensor is defined, then it will fail because the graph itself does not contain the trace for the derivative of that tensor. However, it is possible to work around this latter case by dry-running the function:

    z = Variable(..., requires_grad=True)
    x, y = f(z)
    (x.sum() + y.sum()).backward()
1 parent: fc8ad6f

6 files changed (82 additions, 15 deletions)
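For context, the failure mode and the dry-run workaround described above look roughly like the following in user code. This is an illustrative sketch only: f stands in for a JIT-compiled two-output function, and the shapes and Variable-style API follow the snippet in the commit message.

    import torch
    from torch.autograd import Variable

    # Hypothetical two-output function standing in for a traced/compiled one.
    # If the first recorded call only used x, the trace would lack the
    # derivative path for y and a later call that needs it would fail.
    def f(z):
        return z * 2, z * 3

    z = Variable(torch.randn(3, 4), requires_grad=True)

    # Dry run: touch both outputs once so the recorded trace contains
    # derivatives for both of them (the workaround from the commit message).
    x, y = f(z)
    (x.sum() + y.sum()).backward()

    # Later calls may ignore y; with this change an undefined grad for y
    # is replaced by zeros instead of tripping the VariableFlags check.
    x, _ = f(z)
    x.sum().backward()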

torch/csrc/jit/interpreter.cpp (10 additions, 0 deletions)

@@ -478,6 +478,13 @@ struct InterpreterStateImpl {
     outputs.clear();
     loadTensorsFromRegisters(stage.outputs, outputs);
   }
+  const TensorType & tensorTypeForInput(size_t i) const {
+    size_t graph_i = i;
+    for(size_t s = 0; s < current_stage; s++)
+      graph_i += function->stages[s].inputs.size;
+    JIT_ASSERTM(graph_i < function->graph->inputs().size(), "Input out of range");
+    return *function->graph->inputs().at(graph_i)->type()->expect<TensorType>();
+  }
   int get(const ListHandle<int> & list, int i) {
     return int_data[list.start + i];
   };
@@ -532,6 +539,9 @@ void InterpreterState::runOneStage(
     std::vector<at::Tensor> & outputs) {
   return pImpl->runOneStage(inputs, outputs);
 }
+const TensorType & InterpreterState::tensorTypeForInput(size_t i) const {
+  return pImpl->tensorTypeForInput(i);
+}
 InterpreterState InterpreterState::clone() const {
   return InterpreterState(new InterpreterStateImpl(*pImpl));
 }
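The new tensorTypeForInput helper maps an input index of the current stage to its position in the flat list of graph inputs by skipping over the inputs of all earlier stages. A minimal Python sketch of that index arithmetic (the names are illustrative, not the real interpreter data structures):

    def graph_input_index(stage_input_counts, current_stage, i):
        """Offset input i of the current stage past the inputs of all
        earlier stages, mirroring the loop in tensorTypeForInput."""
        graph_i = i
        for s in range(current_stage):
            graph_i += stage_input_counts[s]
        return graph_i

    # Stage 0 has 3 inputs and stage 1 has 2: input 1 of stage 1 is
    # graph input 3 + 1 = 4.
    assert graph_input_index([3, 2], current_stage=1, i=1) == 4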

torch/csrc/jit/interpreter.h (3 additions, 1 deletion)

@@ -6,14 +6,15 @@ namespace at {
 struct Tensor;
 }
 namespace torch { namespace jit {
-
+
 // The interpreter run Graphs with Tensor inputs and Tensor outputs
 // a separate component in the autograd handles unwrapping and wrapping
 // variable objects for use in the interpreter.
 
 struct CodeImpl;
 struct InterpreterStateImpl;
 struct Graph;
+struct TensorType;
 
 struct Code {
   Code()
@@ -36,6 +37,7 @@ struct InterpreterState {
   void runOneStage(
       const std::vector<at::Tensor> & inputs,
       std::vector<at::Tensor> & outputs);
+  const TensorType & tensorTypeForInput(size_t i) const;
   ~InterpreterState();
   // create a copy of InterpreterState with its current state
   // used when retain_graph=True so that stages can be re-run

torch/csrc/jit/interpreter_autograd_function.cpp (36 additions, 6 deletions)

@@ -5,6 +5,14 @@ namespace torch { namespace jit {
 
 using namespace torch::jit::tracer;
 
+static at::Tensor zeroTensorWithType(const TensorType & type) {
+  auto device = (type.device() < 0)? at::kCPU : at::kCUDA;
+  auto & at_type = at::getType(device, type.scalarType());
+  // note: this has to be a contiguous tensor of zeros, because the fusion engine
+  // specialized to what is normally here which might be fully dense
+  return at_type.zeros(type.sizes());
+}
+
 autograd::variable_list InterpreterAutogradFunction::apply(
     const autograd::variable_list& inputs) {
   // Initial correctness checks.
@@ -19,15 +27,31 @@ autograd::variable_list InterpreterAutogradFunction::apply(
   const auto & details = stage_details_[stage_];
 
   // Validate inputs
-  for (std::size_t i = 0; i < (std::size_t)num_inputs; ++i) {
-    if (!details.input_flags[i].verify(inputs[i])) {
-      throw std::runtime_error("JIT interpreter received inputs with different "
-                               "flags than it was compiled for.");
+  std::vector<at::Tensor> tinputs;
+  tinputs.reserve(inputs.size());
+  TORCH_ASSERT(inputs.size() == num_inputs);
+  TORCH_ASSERT(inputs.size() == details.input_flags.size());
+  for (std::size_t i = 0; i < (std::size_t)inputs.size(); ++i) {
+    if(stage_ > 0 && !inputs[i].defined() && !details.input_flags[i].was_null) {
+      // [Temporary workaround for variants] until tracer produces all variants:
+      // if you have a function x, y = fn(z) and only use x then gradient for y
+      // will be undefined. If you reuse the same trace with and _sometimes_ use y
+      // then in the cases where you don't use it, the grad_y input in stage 1
+      // will be undefined. To ensure we can continue, we create a 0 gradient,
+      // using trace information to figure out what shape it should be
+      tinputs.push_back(zeroTensorWithType(interp_.tensorTypeForInput(i)));
+    } else if(!details.input_flags[i].verify(inputs[i])) {
+      std::stringstream ss;
+      ss << "JIT interpreter received inputs with different "
+         << "flags than it was compiled for. Compiled with " << details.input_flags[i]
+         << " but found " << VariableFlags::of(inputs[i]) << "\n";
+      throw std::runtime_error(ss.str());
+    } else {
+      tinputs.push_back(inputs[i].data());
     }
   }
 
   // Run the interpreter
-  auto tinputs = fmap(inputs, [](const autograd::Variable& i) { return i.data(); });
   std::vector<at::Tensor> toutputs;
   InterpreterState interp = (keep_graph_) ? interp_.clone() : interp_;
   interp.runOneStage(tinputs, toutputs);
@@ -57,7 +81,13 @@ autograd::variable_list InterpreterAutogradFunction::apply(
   }
   // Add grad_fns corresponding to inputs
   for (auto & input : inputs) {
-    if (!input.requires_grad()) continue; // See Note [Null-edge pruning]
+    if (!input.requires_grad()) {
+      continue; // See Note [Null-edge pruning]
+    } else if (!input.defined()) {
+      // See Note [Temporary workaround for variants]
+      grad_fn->next_functions.emplace_back();
+      continue;
+    }
     grad_fn->next_functions.emplace_back(
         input.grad_fn() ? input.grad_fn() : input.grad_accumulator(),
         input.output_nr());
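When a later stage receives an undefined gradient that the trace recorded as defined, the interpreter now substitutes a contiguous tensor of zeros built from the shape, scalar type, and device stored in the traced TensorType. A rough Python analogue of zeroTensorWithType, assuming the traced metadata has already been unpacked into plain values (torch.zeros stands in for the at_type.zeros call):

    import torch

    def zero_tensor_with_type(sizes, dtype, device_index):
        """Rough analogue of zeroTensorWithType: a negative device index
        means CPU in the traced metadata, otherwise a CUDA device."""
        device = 'cpu' if device_index < 0 else 'cuda'
        # Contiguous zeros on purpose: the fusion engine is specialized to
        # the layout normally seen here, which may be fully dense.
        return torch.zeros(sizes, dtype=dtype, device=device)

    # Stand-in for a missing grad input with traced shape [4, 5] on CPU.
    grad_y = zero_tensor_with_type([4, 5], torch.float32, -1)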

torch/csrc/jit/interpreter_autograd_function.h (5 additions, 1 deletion)

@@ -20,7 +20,11 @@ struct InterpreterAutogradFunction : public autograd::Function {
       const std::vector<StageDetails>& stage_details)
     : interp_(code)
     , stage_details_(stage_details)
-    , stage_(0) {}
+    , stage_(0) {
+      // stage 0 isn't run through the autograd, so we set this
+      // here just in case it is used
+      num_inputs = stage_details.at(0).input_flags.size();
+    }
 
   InterpreterAutogradFunction(InterpreterState interp,
       const std::vector<StageDetails>& stage_details,

torch/csrc/jit/python_compiled_function.cpp (20 additions, 6 deletions)

@@ -200,18 +200,32 @@ CompiledFunction::TraceForKey* getTraceFor(CompiledFunction& fn,
 
 } // anonymous namespace
 
+static py::tuple tuple_tail(const py::tuple & tup) {
+  py::tuple r(tup.size() - 1);
+  for(int i = 1; i < tup.size(); i++) {
+    r[i-1] = tup[i];
+  }
+  return r;
+}
+
 void initCompilerMixin(PyObject *module) {
   auto m = py::handle(module).cast<py::module>();
   py::class_<CompiledFunction>(m, "CompiledFunction", py::dynamic_attr())
     .def(py::init<int, bool, bool, py::object, std::string>())
-    .def("__call__", [](CompiledFunction& fn, py::args args) -> py::object {
-      return fn.call(args);
+    .def("__call__", [](py::args args_) -> py::object {
+      auto fn = py::cast<CompiledFunction*>(args_[0]);
+      auto args = tuple_tail(args_);
+      return fn->call(args);
     })
-    .def("has_trace_for", [](CompiledFunction& fn, py::args args) -> bool {
-      return getTraceFor(fn, args) != nullptr;
+    .def("has_trace_for", [](py::args args_) -> bool {
+      auto fn = py::cast<CompiledFunction*>(args_[0]);
+      auto args = tuple_tail(args_);
+      return getTraceFor(*fn, args) != nullptr;
     })
-    .def("graph_for", [](CompiledFunction& fn, py::args args) -> py::object {
-      auto trace = getTraceFor(fn, args);
+    .def("graph_for", [](py::args args_) -> py::object {
+      auto fn = py::cast<CompiledFunction*>(args_[0]);
+      auto args = tuple_tail(args_);
+      auto trace = getTraceFor(*fn, args);
       return trace ? py::cast(trace->graph_) : py::none();
     })
     .def("clear_cache", [](CompiledFunction& fn) {

torch/csrc/jit/variable_flags.h (8 additions, 1 deletion)

@@ -1,5 +1,5 @@
 #pragma once
-
+#include <iostream>
 namespace torch { namespace autograd {
 struct Variable;
 }}
@@ -15,4 +15,11 @@ struct VariableFlags {
   bool was_null;
 };
 
+static inline std::ostream & operator<<(std::ostream & out, const VariableFlags& v) {
+  return out
+    << "(requires_grad=" << v.requires_grad
+    << ", is_volatile=" << v.is_volatile
+    << ", was_null=" << v.was_null << ")";
+}
+
 }}
