Skip to content

Commit 6bd1a3f

Browse files
committed
fix(//core): Do not compile hidden methods
(methods prefixed with _) Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent dcb1474 commit 6bd1a3f

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

core/compiler.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,16 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
150150
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
151151
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
152152
for (const torch::jit::script::Method& method : mod.get_methods()) {
153-
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
154-
auto new_g = std::make_shared<torch::jit::Graph>();
155-
AddEngineToGraph(new_mod, new_g, engine);
156-
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
157-
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
158-
new_mod.type()->addMethod(new_method);
159-
new_method->setSchema(schema);
153+
// Don't convert hidden methods
154+
if (method.name().rfind("_", 0)) {
155+
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
156+
auto new_g = std::make_shared<torch::jit::Graph>();
157+
AddEngineToGraph(new_mod, new_g, engine);
158+
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
159+
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
160+
new_mod.type()->addMethod(new_method);
161+
new_method->setSchema(schema);
162+
}
160163
}
161164

162165
return new_mod;

0 commit comments

Comments
 (0)