Skip to content

Commit 07bb864

Browse files
committed
Fix another leak in pybind11 code.
This time caused by an upstream pybind11 bug: pybind/pybind11#1216 This changes causes the code to go down a non-buggy pathway.
1 parent c6381c6 commit 07bb864

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

torch/csrc/jit/python_compiled_function.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,32 @@ CompiledFunction::TraceForKey* getTraceFor(CompiledFunction& fn,
200200

201201
} // anonymous namespace
202202

203+
static py::tuple tuple_tail(const py::tuple & tup) {
204+
py::tuple r(tup.size() - 1);
205+
for(int i = 1; i < tup.size(); i++) {
206+
r[i-1] = tup[i];
207+
}
208+
return r;
209+
}
210+
203211
void initCompilerMixin(PyObject *module) {
204212
auto m = py::handle(module).cast<py::module>();
205213
py::class_<CompiledFunction>(m, "CompiledFunction", py::dynamic_attr())
206214
.def(py::init<int, bool, bool, py::object, std::string>())
207-
.def("__call__", [](CompiledFunction& fn, py::args args) -> py::object {
208-
return fn.call(args);
215+
.def("__call__", [](py::args args_) -> py::object {
216+
auto fn = py::cast<CompiledFunction*>(args_[0]);
217+
auto args = tuple_tail(args_);
218+
return fn->call(args);
209219
})
210-
.def("has_trace_for", [](CompiledFunction& fn, py::args args) -> bool {
211-
return getTraceFor(fn, args) != nullptr;
220+
.def("has_trace_for", [](py::args args_) -> bool {
221+
auto fn = py::cast<CompiledFunction*>(args_[0]);
222+
auto args = tuple_tail(args_);
223+
return getTraceFor(*fn, args) != nullptr;
212224
})
213-
.def("graph_for", [](CompiledFunction& fn, py::args args) -> py::object {
214-
auto trace = getTraceFor(fn, args);
225+
.def("graph_for", [](py::args args_) -> py::object {
226+
auto fn = py::cast<CompiledFunction*>(args_[0]);
227+
auto args = tuple_tail(args_);
228+
auto trace = getTraceFor(*fn, args);
215229
return trace ? py::cast(trace->graph_) : py::none();
216230
})
217231
.def("clear_cache", [](CompiledFunction& fn) {

0 commit comments

Comments
 (0)