Skip to content

Commit e5346e2

Browse files
committed
PR comments
1 parent 2ef2bf8 commit e5346e2

File tree

4 files changed

+26
-21
lines changed

4 files changed

+26
-21
lines changed

torch/csrc/jit/script/compiler.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ struct Environment {
9090

9191
return sv;
9292
}
93-
93+
Block* block() {
94+
return b;
95+
}
9496
Symbol getBlockOwningKind() {
9597
Symbol owning_kind = Symbol();
9698
if (b->owningNode()) {
@@ -538,10 +540,18 @@ struct to_ir {
538540
// unrolled
539541
auto sv = emitSugaredExpr(itrs[0]);
540542
auto instances = sv->unrolledFor(stmt.range(), method);
543+
const std::string& target_name = targets[0].name();
544+
pushFrame(environment_stack->block());
541545
for(auto inst : instances) {
542-
environment_stack->setSugaredVar(targets[0].name(), inst);
546+
environment_stack->setSugaredVar(target_name, inst);
543547
emitStatements(body);
544548
}
549+
for (const auto & n : environment_stack->definedVariables()) {
550+
if (environment_stack->findInParentFrame(n)) {
551+
environment_stack->next->setVar(n, environment_stack->getVar(n, stmt.range()));
552+
}
553+
}
554+
popFrame();
545555
}
546556

547557
void emitWhile(const While& stmt) {

torch/csrc/jit/script/compiler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
4646
}
4747

4848
// use it in `for i in this: ...`
49-
// in this case we will unroll the loop boy by assigning i to each of
49+
// in this case we will unroll the loop body by assigning i to each of
5050
// the SugaredValues returned from this method.
5151
virtual std::vector<std::shared_ptr<SugaredValue>> unrolledFor(SourceRange loc, Method& m) {
5252
throw ErrorReport(loc) << kind() << " is not iterable";

torch/csrc/jit/script/init.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,6 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
8888
py::object self;
8989
};
9090

91-
// get the SugaredValue for something inside a torch.jit.Const
92-
// this can either be a ConstantPythonValue or a ModuleValue
93-
static std::shared_ptr<SugaredValue> createConstantSugaredValue(py::object obj);
94-
9591
// by using torch.jit.Const, a user can mark a python value constant
9692
// we then make that value immutable.
9793
// once marked constant, we enable additional behavior such as
@@ -116,7 +112,7 @@ struct VISIBILITY_HIDDEN ConstantPythonValue : public PythonValue {
116112
py::tuple tup = self;
117113
std::vector<std::shared_ptr<SugaredValue>> result;
118114
for(size_t i = 0; i < tup.size(); ++i) {
119-
result.push_back(createConstantSugaredValue(tup[i]));
115+
result.push_back(std::make_shared<ConstantPythonValue>(tup[i]));
120116
}
121117
return result;
122118
}
@@ -195,7 +191,7 @@ struct ModuleValue : public SugaredValue {
195191
py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) {
196192
return std::make_shared<PythonValue>(attr);
197193
} else if(py_module.attr("_constants_set").contains(field.c_str())) {
198-
return createConstantSugaredValue(attr);
194+
return std::make_shared<ConstantPythonValue>(attr);
199195
} else {
200196
throw ErrorReport(loc) << "attribute '" << field << "' of type '" << typeString(attr) << "' is not usable in a script method (did you forget to add it __constants__?)";
201197
}
@@ -213,8 +209,13 @@ struct ModuleValue : public SugaredValue {
213209
return SugaredValue::unrolledFor(loc, m);
214210
std::vector<std::shared_ptr<SugaredValue>> result;
215211
for(py::handle module : py_module) {
216-
result.push_back(createConstantSugaredValue(
217-
py::reinterpret_borrow<py::object>(module)));
212+
py::object obj = py::reinterpret_borrow<py::object>(module);
213+
if(py::isinstance<Module>(obj)) {
214+
auto r = py::cast<std::shared_ptr<Module>>(obj);
215+
result.push_back(std::make_shared<ModuleValue>(r));
216+
} else {
217+
result.push_back(std::make_shared<ConstantPythonValue>(obj));
218+
}
218219
}
219220
return result;
220221
}
@@ -223,13 +224,6 @@ struct ModuleValue : public SugaredValue {
223224
std::shared_ptr<Module> module;
224225
};
225226

226-
static std::shared_ptr<SugaredValue> createConstantSugaredValue(py::object obj) {
227-
if(py::isinstance<Module>(obj)) {
228-
auto r = py::cast<std::shared_ptr<Module>>(obj);
229-
return std::make_shared<ModuleValue>(r);
230-
}
231-
return std::make_shared<ConstantPythonValue>(obj);
232-
}
233227

234228
// TODO: dedup with other init
235229

torch/jit/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,14 +656,15 @@ def __getitem__(self, k):
656656
raise KeyError(k)
657657
return self.module._get_parameter(k)
658658

659-
# base types that can be constants in addition to tuples and lists
659+
# base types that can be constants
660+
# in addition, tuples and lists of these base types are also considered constants
660661
# If you edit this list, then you also need to edit the handlers in
661662
# ConstantValue in jit/script/init.cpp
662-
_constant_types = [bool, float, int, types.FunctionType]
663+
_constant_types = (bool, float, int, types.FunctionType)
663664

664665

665666
def _get_valid_constant(v):
666-
if any(isinstance(v, typ) for typ in _constant_types):
667+
if isinstance(v, _constant_types):
667668
return v
668669
elif isinstance(v, tuple) or isinstance(v, list):
669670
return tuple(_get_valid_constant(x) for x in v)

0 commit comments

Comments
 (0)