@@ -150,13 +150,16 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
150
150
torch::jit::script::Module new_mod (mod._ivalue ()->name () + " _trt" );
151
151
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
152
152
for (const torch::jit::script::Method& method : mod.get_methods ()) {
153
- auto engine = ConvertGraphToTRTEngine (mod, method.name (), cfg);
154
- auto new_g = std::make_shared<torch::jit::Graph>();
155
- AddEngineToGraph (new_mod, new_g, engine);
156
- auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
157
- auto schema = GenerateGraphSchema (new_mod, new_method->name (), new_g);
158
- new_mod.type ()->addMethod (new_method);
159
- new_method->setSchema (schema);
153
+ // Don't convert hidden methods
154
+ if (method.name ().rfind (" _" , 0 )) {
155
+ auto engine = ConvertGraphToTRTEngine (mod, method.name (), cfg);
156
+ auto new_g = std::make_shared<torch::jit::Graph>();
157
+ AddEngineToGraph (new_mod, new_g, engine);
158
+ auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
159
+ auto schema = GenerateGraphSchema (new_mod, new_method->name (), new_g);
160
+ new_mod.type ()->addMethod (new_method);
161
+ new_method->setSchema (schema);
162
+ }
160
163
}
161
164
162
165
return new_mod;
0 commit comments