diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 7d20c0ac383b2..5de1ed0e54b17 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -13,11 +13,13 @@
 from test_jit import JitTestCase, RUN_CUDA
 
-
 if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
     torch._C._jit_set_profiling_executor(True)
     torch._C._jit_set_profiling_mode(True)
 
+FUSION_GROUP = 'prim::CudaFusionGroup'
+
+
 class TestCudaFuser(JitTestCase):
 
     def setUp(self):
@@ -28,16 +30,16 @@ def setUp(self):
         torch._C._jit_override_can_fuse_on_gpu(False)
         if(RUN_CUDA):
-            torch._C._jit_register_cuda_fuser()
+            self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)
 
     def tearDown(self):
         if(RUN_CUDA):
-            torch._C._jit_clear_cuda_fuser()
+            torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
         torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
         torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
         super(TestCudaFuser, self).tearDown()
 
-    def _run_helper(self, jit_op, op, should_fuse, *args):
+    def _run_helper(self, jit_op, op, *args):
         torch.cuda.manual_seed_all(123)
         jit_o = jit_op(*args)
         torch.cuda.manual_seed_all(123)
@@ -45,20 +47,14 @@ def _run_helper(self, jit_op, op, should_fuse, *args):
         torch.cuda.manual_seed_all(123)
         o = op(*args)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(jit_op.graph_for(*args)) == should_fuse)
-
-    def _has_cuda_fusion_group(self, graph):
-        has_cuda_fusion_group = False
-        for n in graph.nodes():
-            if n.kind() == 'prim::CudaFusionGroup':
-                has_cuda_fusion_group = True
-        return has_cuda_fusion_group
+        self.assertGraphContains(jit_op.graph_for(*args), FUSION_GROUP)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
    @skipIfRocm
     def test_half(self):
-        def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, alpha : float):
+        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
             o_16 = torch.add(x, y)
             o_32_a = torch.add(y, z, alpha=alpha)
             o_32_b = torch.add(o_16, z)
@@ -77,10 +73,11 @@ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, alpha : float):
         for oo, jit_oo in zip(o, jit_o):
             self.assertEqual(oo.dtype, jit_oo.dtype)
             self.assertEqual(oo, jit_oo)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z, alpha)))
+        self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GROUP)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_const(self):
         def t(x, y):
@@ -94,10 +91,11 @@ def t(x, y):
         jit_o = t_jit(x, y)
         o = t(x, y)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y)))
+        self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_chunk(self):
         def t(x, y, z, q):
@@ -117,13 +115,14 @@ def t(x, y, z, q):
         jit_o = t_jit(x, y, z, q)
         o = t(x, y, z, q)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z, q)))
+        self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GROUP)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_scalar_input(self):
-        def t(x : torch.Tensor, y : torch.Tensor, z : float):
+        def t(x: torch.Tensor, y: torch.Tensor, z: float):
             o = x + y
             o = o + z
             return o
@@ -135,13 +134,14 @@ def t(x : torch.Tensor, y : torch.Tensor, z : float):
         jit_o = t_jit(x, y, 2.0)
         o = t(x, y, 2.0)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_broadcasting(self):
-        def t(x : torch.Tensor, y : torch.Tensor, z : float):
+        def t(x: torch.Tensor, y: torch.Tensor, z: float):
             o = x + y
             o = o + z
             return o
@@ -152,14 +152,15 @@ def t(x : torch.Tensor, y : torch.Tensor, z : float):
         jit_o = t_jit(x, y, 2.0)
         o = t(x, y, 2.0)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)
 
     @unittest.skipIf(True, "real broadcast with different output not supported yet")
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_broadcasting_multiple_output_shape(self):
-        def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
             o = x + 12
             o1 = o + y
             o2 = o + z
@@ -174,13 +175,14 @@ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
         o = t(x, y, z)
         self.assertEqual(o, jit_o)
         # Currently cannot fuse this
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
+        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_broadcasting_multiple_output(self):
-        def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
+        def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
             o = x + 12
             o1 = o + y
             o2 = o + z
@@ -195,10 +197,10 @@ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
         o = t(x, y, z)
         self.assertEqual(o, jit_o)
         # Currently cannot fuse this
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
+        self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)
 
     def _binary_test_helper(self, operation):
-        def t(x : torch.Tensor, y: torch.Tensor, z : float):
+        def t(x: torch.Tensor, y: torch.Tensor, z: float):
             o = x + z
             o = operation(o, y)
             return o
@@ -209,10 +211,10 @@ def t(x : torch.Tensor, y: torch.Tensor, z : float):
         jit_o = t_jit(x, y, 2.0)
         o = t(x, y, 2.0)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)
 
     def _unary_test_helper(self, operation):
-        def t(x : torch.Tensor, z : float):
+        def t(x: torch.Tensor, z: float):
             o = x + z
             o = operation(o)
             return o
@@ -222,10 +224,11 @@ def t(x : torch.Tensor, z : float):
         jit_o = t_jit(x, 2.0)
         o = t(x, 2.0)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, 2.0)))
+        self.assertGraphContains(t_jit.graph_for(x, 2.0), FUSION_GROUP)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_unary_ops(self):
         operations = [torch.neg,
@@ -262,7 +265,8 @@ def test_unary_ops(self):
             self._unary_test_helper(op)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_binary_ops(self):
         operations = [torch.div,
@@ -283,7 +287,8 @@ def test_binary_ops(self):
             self._binary_test_helper(op)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
+    # legacy fuser does not work for rand_like, see issue #34361
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_ternary_ops(self):
         x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
@@ -291,54 +296,54 @@ def test_ternary_ops(self):
         z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
         cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")
 
-        def add(x : torch.Tensor, other : torch.Tensor, alpha : float):
+        def add(x: torch.Tensor, other: torch.Tensor, alpha: float):
             o = torch.relu(x)
             o = torch.add(o, other=other, alpha=alpha)
             return o
         add_jit = torch.jit.script(add)
-        self._run_helper(add_jit, add, True, x, y, 2.0)
+        self._run_helper(add_jit, add, x, y, 2.0)
 
-        def clamp0(x : torch.Tensor, f : float):
+        def clamp0(x: torch.Tensor, f: float):
             o = torch.rand_like(x)
             o = o * torch.clamp(x, min=f)
             return o
         clamp0_jit = torch.jit.script(clamp0)
-        self._run_helper(clamp0_jit, clamp0, True, x, 0.5)
+        self._run_helper(clamp0_jit, clamp0, x, 0.5)
 
-        def clamp1(x : torch.Tensor, f : float, ff : float):
+        def clamp1(x: torch.Tensor, f: float, ff: float):
             o = torch.rand_like(x)
             o = o * torch.clamp(x, min=f, max=ff)
             return o
         clamp1_jit = torch.jit.script(clamp1)
-        self._run_helper(clamp1_jit, clamp1, True, x, -0.2, 0.7)
+        self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7)
 
-        def threshold(x : torch.Tensor, th : float, val : float):
+        def threshold(x: torch.Tensor, th: float, val: float):
             o = torch.rand_like(x)
             o = x * torch.threshold(o, th, val)
             return o
         threshold_jit = torch.jit.script(threshold)
-        self._run_helper(threshold_jit, threshold, True, x, 0.2, 0.9)
+        self._run_helper(threshold_jit, threshold, x, 0.2, 0.9)
 
-        def where(x : torch.Tensor, y : torch.Tensor, cond : torch.Tensor):
+        def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor):
             o = torch.rand_like(x)
             o = o * torch.where(cond, x, y)
             return o
         where_jit = torch.jit.script(where)
-        self._run_helper(where_jit, where, True, x, y, cond)
+        self._run_helper(where_jit, where, x, y, cond)
 
         def lerp(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
             o = torch.rand_like(x)
             o = o * torch.lerp(x, y, z)
             return o
         lerp_jit = torch.jit.script(lerp)
-        self._run_helper(lerp_jit, lerp, True, x, y, z)
+        self._run_helper(lerp_jit, lerp, x, y, z)
 
         def lerp_scale(x : torch.Tensor, y : torch.Tensor, z: float):
             o = torch.rand_like(x)
             o = o * torch.lerp(x, y, z)
             return o
         lerp_scale_jit = torch.jit.script(lerp_scale)
-        self._run_helper(lerp_scale_jit, lerp_scale, True, x, y, 0.5)
+        self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@@ -353,27 +358,28 @@ def addcmul(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, value : float):
             o = torch.addcmul(o, y, z, value=value)
             return o
         addcmul_jit = torch.jit.script(addcmul)
-        self._run_helper(addcmul_jit, addcmul, True, x, y, z, 2.0)
+        self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)
 
         def addcmul_no_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
             o = torch.add(x, 0.5)
             o = torch.addcmul(o, y, z)
             return o
         addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
-        self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, True, x, y, z)
+        self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)
 
         def addcmul_const_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
             o = torch.add(x, 0.5)
             o = torch.addcmul(o, y, z, value=0.75)
             return o
         addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha)
-        self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, True, x, y, z)
+        self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z)
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires profiling node to run cuda fuser")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
     @skipIfRocm
     def test_dynamic_size(self):
-        def t(x : torch.Tensor, y : torch.Tensor, z : float):
+        def t(x: torch.Tensor, y: torch.Tensor, z: float):
             o = x + y
             o = o + z
             return o
@@ -384,7 +390,7 @@ def t(x : torch.Tensor, y : torch.Tensor, z : float):
         jit_o = t_jit(x, y, 2.0)
         o = t(x, y, 2.0)
         self.assertEqual(o, jit_o)
-        self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
+        self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)
         x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda")
         y = torch.randn(16, 8, dtype=torch.float, device="cuda")
         jit_o = t_jit(x, y, 2.0)
@@ -400,5 +406,55 @@ def test_random_topo(self):
         os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
         self.assertTrue(runDefaultTestWithSeed(28449))
 
+
+class TestPassManagerCudaFuser(JitTestCase):
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
+                     ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
+    @skipIfRocm
+    def test_context_manager_test(self):
+        x = torch.randn(4, 8, dtype=torch.float, device="cuda")
+        y = torch.randn(4, 8, dtype=torch.float, device="cuda")
+        with torch.jit.fuser('fuser2'):
+            with torch.jit.fuser('fuser2'):
+
+                def t1(x, y):
+                    o = x + y
+                    o = o + 2.0
+                    return o
+                t_jit = torch.jit.script(t1)
+                t_jit(x, y)
+                t_jit(x, y)
+                self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP)
+
+            def t2(x, y):
+                o = x + y
+                o = o + 3.0
+                return o
+            t_jit_2 = torch.jit.script(t2)
+            t_jit_2(x, y)
+            t_jit_2(x, y)
+            self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GROUP)
+
+        def t3(x, y):
+            o = x + y
+            o = o + 4.0
+            return o
+        t_jit_3 = torch.jit.script(t3)
+        t_jit_3(x, y)
+        t_jit_3(x, y)
+        self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GROUP, 0)
+
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @skipIfRocm
+    def test_register_fuser(self):
+        self.assertFalse(torch._C._jit_set_nvfuser_enabled(True))
+        self.assertTrue(torch._C._jit_nvfuser_enabled())
+        self.assertTrue(torch._C._jit_set_nvfuser_enabled(True))
+        self.assertTrue(torch._C._jit_nvfuser_enabled())
+        self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
+        self.assertFalse(torch._C._jit_nvfuser_enabled())
+
 if __name__ == '__main__':
     run_tests()
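Note on the setUp/tearDown change above: it relies on the new binding returning the previous state. `torch._C._jit_set_nvfuser_enabled(flag)` flips nvFuser on or off and hands back the old flag, while `torch._C._jit_nvfuser_enabled()` queries it. A minimal save-and-restore sketch (assuming a CUDA build, where the binding's TORCH_CHECK passes):

    import torch

    # Flip nvFuser on; the call returns whatever the flag was before.
    old_state = torch._C._jit_set_nvfuser_enabled(True)
    try:
        assert torch._C._jit_nvfuser_enabled()
        # ... run scripted workloads with nvFuser active ...
    finally:
        # Restore the flag exactly as it was.
        torch._C._jit_set_nvfuser_enabled(old_state)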
diff --git a/torch/csrc/jit/passes/cuda_graph_fuser.h b/torch/csrc/jit/passes/cuda_graph_fuser.h
index c227f2b0fafb8..0f821845613b2 100644
--- a/torch/csrc/jit/passes/cuda_graph_fuser.h
+++ b/torch/csrc/jit/passes/cuda_graph_fuser.h
@@ -11,11 +11,21 @@ namespace jit {
 
 // Register CudaFuseGraph in custom passes
 struct C10_EXPORT RegisterCudaFuseGraph : public PassManager {
-  static void registerPass() {
+  static bool registerPass(bool enabled) {
     TORCH_CHECK(
         at::globalContext().hasCUDA() && !at::globalContext().hasHIP(),
         "Running CUDA fuser is only supported on CUDA builds.");
-    PassManager::registerPass(fuser::cuda::fuseGraph);
+    bool old_flag = PassManager::isRegistered();
+    if (enabled) {
+      PassManager::registerPass(fuser::cuda::fuseGraph);
+    } else {
+      PassManager::clearPass();
+    }
+    return old_flag;
+  }
+
+  static bool isRegistered() {
+    return PassManager::isRegistered();
   }
 };
diff --git a/torch/csrc/jit/passes/pass_manager.h b/torch/csrc/jit/passes/pass_manager.h
index bbd6452ca58ff..73109f3d25460 100644
--- a/torch/csrc/jit/passes/pass_manager.h
+++ b/torch/csrc/jit/passes/pass_manager.h
@@ -108,14 +108,17 @@ struct C10_EXPORT PassManager {
 
  public:
   // registerPass(pass) will register the pass provided and set the
-  // name/isRegistered functions appropriately
-  static void registerPass(GraphPass p) {
+  // name/isRegistered functions appropriately. It returns a bool
+  // indicating whether a pass was already registered previously.
+  static bool registerPass(GraphPass p) {
     if (!isRegistered()) {
       // If we don't already have a registered pass, register pass
       // hold on to its name, change isRegistered to true
       passID(registerPostPass(std::move(p)), true);
       isRegistered(true);
+      return false;
     }
+    return true;
   }
 
   // Calls ClearPostPass(passID())
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index d5ace78ff1b55..8fb80776dc31b 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -400,8 +400,8 @@ void initJITBindings(PyObject* module) {
             auto stack = toTraceableStack(args);
             checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
           })
-      .def("_jit_register_cuda_fuser", &RegisterCudaFuseGraph::registerPass)
-      .def("_jit_clear_cuda_fuser", &RegisterCudaFuseGraph::clearPass)
+      .def("_jit_set_nvfuser_enabled", &RegisterCudaFuseGraph::registerPass)
+      .def("_jit_nvfuser_enabled", &RegisterCudaFuseGraph::isRegistered)
       .def(
           "_jit_set_profiling_mode",
           [](bool profiling_flag) {
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 06f9810cadf5f..f9bad44c96418 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -84,6 +84,42 @@ def optimized_execution(should_optimize):
     finally:
         torch._C._set_graph_executor_optimize(stored_flag)
 
+@contextlib.contextmanager
+def fuser(name):
+    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
+    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
+    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
+    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
+    if name == 'fuser0':  # legacy fuser
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(False)
+    elif name == 'fuser1':  # NNC
+        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
+        old_profiling_mode = torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(True)
+        torch._C._jit_set_nvfuser_enabled(False)
+    elif name == 'fuser2':  # nvFuser
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+    else:
+        raise Exception("unrecognized fuser option")
+    try:
+        yield
+    finally:
+        if name == 'fuser1':  # NNC
+            torch._C._jit_set_profiling_executor(old_profiling_executor)
+            torch._C._jit_set_profiling_mode(old_profiling_mode)
+        # recover the previous values
+        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
+        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
+        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
+        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
 
 DEFAULT_EXTRA_FILES_MAP = torch._C.ExtraFilesMap()
diff --git a/torch/testing/_internal/codegen/random_topo_test.py b/torch/testing/_internal/codegen/random_topo_test.py
index fa339abaa581e..41e6917ef1354 100644
--- a/torch/testing/_internal/codegen/random_topo_test.py
+++ b/torch/testing/_internal/codegen/random_topo_test.py
@@ -358,7 +358,7 @@ def parse_args():
 
     # Register CUDA fuser
     if args.cuda_fuser:
-        torch._C._jit_register_cuda_fuser()
+        torch._C._jit_set_nvfuser_enabled(True)
 
     # Turn off legacy fuser
     if not args.legacy_fuser:
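Taken together, these pieces wire `RegisterCudaFuseGraph` into a Python-facing toggle, and the new `torch.jit.fuser` context manager selects a fuser by name and restores every toggle on exit: 'fuser0' is the legacy fuser, 'fuser1' is NNC (the TensorExpr fuser), and 'fuser2' is nvFuser. A usage sketch mirroring the new test (the function `f` and the tensor shapes are illustrative, and a CUDA build is assumed):

    import torch

    def f(x, y):
        o = x + y
        return o + 2.0

    x = torch.randn(4, 8, device="cuda")
    y = torch.randn(4, 8, device="cuda")

    with torch.jit.fuser('fuser2'):      # route fusion through nvFuser
        f_jit = torch.jit.script(f)
        f_jit(x, y)                      # warm-up runs so the profiling
        f_jit(x, y)                      # executor can insert fusion groups
        # The optimized graph should now contain a prim::CudaFusionGroup node.
        print(f_jit.graph_for(x, y))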