diff --git a/core/lowering/passes/module_fallback.cpp b/core/lowering/passes/module_fallback.cpp index 248f45edcb..415b385634 100644 --- a/core/lowering/passes/module_fallback.cpp +++ b/core/lowering/passes/module_fallback.cpp @@ -43,12 +43,17 @@ void NotateModuleForFallback( "Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]"); auto uses = n->output(0)->uses(); + int k = 0; for (const auto u : uses) { + auto compilation_context_node = g->createNone(); + auto compilation_context = compilation_context_node->outputs()[0]; + compilation_context->setDebugName("compilation_context_" + std::to_string(k++)); auto user = u.user; - auto delim_start_n = g->create(torch::jit::prim::Enter, 0); + auto delim_start_n = g->create(torch::jit::prim::Enter, {compilation_context}); delim_start_n->s_(c10::Symbol::attr("compilation_edge"), "start"); - auto delim_end_n = g->create(torch::jit::prim::Exit, 0); + auto delim_end_n = g->create(torch::jit::prim::Exit, {compilation_context}); delim_end_n->s_(c10::Symbol::attr("compilation_edge"), "end"); + compilation_context_node->insertBefore(user); delim_start_n->insertBefore(user); delim_end_n->insertAfter(user); }