@@ -88,10 +88,6 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
88
88
py::object self;
89
89
};
90
90
91
- // get the SugaredValue for something inside a torch.jit.Const
92
- // this can either be a ConstantPythonValue or a ModuleValue
93
- static std::shared_ptr<SugaredValue> createConstantSugaredValue (py::object obj);
94
-
95
91
// by using torch.jit.Const, a user can mark a python value constant
96
92
// we then make that value immutable.
97
93
// once marked constant, we enable additional behavior such as
@@ -116,7 +112,7 @@ struct VISIBILITY_HIDDEN ConstantPythonValue : public PythonValue {
116
112
py::tuple tup = self;
117
113
std::vector<std::shared_ptr<SugaredValue>> result;
118
114
for (size_t i = 0 ; i < tup.size (); ++i) {
119
- result.push_back (createConstantSugaredValue (tup[i]));
115
+ result.push_back (std::make_shared<ConstantPythonValue> (tup[i]));
120
116
}
121
117
return result;
122
118
}
@@ -195,7 +191,7 @@ struct ModuleValue : public SugaredValue {
195
191
py::isinstance (attr, py::module::import (" torch.nn" ).attr (" Module" ))) {
196
192
return std::make_shared<PythonValue>(attr);
197
193
} else if (py_module.attr (" _constants_set" ).contains (field.c_str ())) {
198
- return createConstantSugaredValue (attr);
194
+ return std::make_shared<ConstantPythonValue> (attr);
199
195
} else {
200
196
throw ErrorReport (loc) << " attribute '" << field << " ' of type '" << typeString (attr) << " ' is not usable in a script method (did you forget to add it __constants__?)" ;
201
197
}
@@ -213,8 +209,13 @@ struct ModuleValue : public SugaredValue {
213
209
return SugaredValue::unrolledFor (loc, m);
214
210
std::vector<std::shared_ptr<SugaredValue>> result;
215
211
for (py::handle module : py_module) {
216
- result.push_back (createConstantSugaredValue (
217
- py::reinterpret_borrow<py::object>(module)));
212
+ py::object obj = py::reinterpret_borrow<py::object>(module);
213
+ if (py::isinstance<Module>(obj)) {
214
+ auto r = py::cast<std::shared_ptr<Module>>(obj);
215
+ result.push_back (std::make_shared<ModuleValue>(r));
216
+ } else {
217
+ result.push_back (std::make_shared<ConstantPythonValue>(obj));
218
+ }
218
219
}
219
220
return result;
220
221
}
@@ -223,13 +224,6 @@ struct ModuleValue : public SugaredValue {
223
224
std::shared_ptr<Module> module;
224
225
};
225
226
226
- static std::shared_ptr<SugaredValue> createConstantSugaredValue (py::object obj) {
227
- if (py::isinstance<Module>(obj)) {
228
- auto r = py::cast<std::shared_ptr<Module>>(obj);
229
- return std::make_shared<ModuleValue>(r);
230
- }
231
- return std::make_shared<ConstantPythonValue>(obj);
232
- }
233
227
234
228
// TODO: dedup with other init
235
229
0 commit comments