test/test_jit_cuda_fuser.py (168 changes: 112 additions & 56 deletions)
@@ -13,11 +13,13 @@

from test_jit import JitTestCase, RUN_CUDA


if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)

FUSION_GROUP = 'prim::CudaFusionGroup'


class TestCudaFuser(JitTestCase):

def setUp(self):
@@ -28,37 +30,31 @@ def setUp(self):
torch._C._jit_override_can_fuse_on_gpu(False)

if(RUN_CUDA):
torch._C._jit_register_cuda_fuser()
self.old_nvfuser = torch._C._jit_set_nvfuser_enabled(True)

def tearDown(self):
if(RUN_CUDA):
torch._C._jit_clear_cuda_fuser()
torch._C._jit_set_nvfuser_enabled(self.old_nvfuser)
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuse)
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuse)
super(TestCudaFuser, self).tearDown()
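A minimal sketch (not part of this diff) of the save/restore pattern the new API enables, assuming, as setUp and tearDown above imply, that torch._C._jit_set_nvfuser_enabled(flag) returns the previous state; run_fusion_tests is a hypothetical placeholder:

import torch

old_state = torch._C._jit_set_nvfuser_enabled(True)  # enable; remember prior state
try:
    run_fusion_tests()  # hypothetical stand-in for fusion-dependent work
finally:
    torch._C._jit_set_nvfuser_enabled(old_state)  # restore the previous setting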

def _run_helper(self, jit_op, op, should_fuse, *args):
def _run_helper(self, jit_op, op, *args):
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
o = op(*args)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(jit_op.graph_for(*args)) == should_fuse)

def _has_cuda_fusion_group(self, graph):
has_cuda_fusion_group = False
for n in graph.nodes():
if n.kind() == 'prim::CudaFusionGroup':
has_cuda_fusion_group = True
return has_cuda_fusion_group
self.assertGraphContains(jit_op.graph_for(*args), FUSION_GROUP)
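The hand-rolled _has_cuda_fusion_group loop removed above is superseded by the JitTestCase graph assertions; a self-contained usage sketch (the scripted function and tensors are illustrative, and the two warm-up calls mirror the profile-then-optimize pattern used throughout this file):

import torch
from test_jit import JitTestCase

class FusionCheckSketch(JitTestCase):
    def test_pointwise_fusion(self):
        @torch.jit.script
        def f(a, b):
            return a + b + 1.0

        a = torch.randn(4, 4, device="cuda")
        b = torch.randn(4, 4, device="cuda")
        f(a, b)  # first call: profiling run
        f(a, b)  # second call: optimized graph is materialized
        # passes iff at least one FUSION_GROUP node appears in the graph
        self.assertGraphContains(f.graph_for(a, b), 'prim::CudaFusionGroup')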

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_half(self):
def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, alpha : float):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, alpha: float):
o_16 = torch.add(x, y)
o_32_a = torch.add(y, z, alpha=alpha)
o_32_b = torch.add(o_16, z)
@@ -77,10 +73,11 @@ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, alpha : float):
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z, alpha)))
self.assertGraphContains(t_jit.graph_for(x, y, z, alpha), FUSION_GROUP)
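test_half checks output dtypes as well as values; the expectation follows standard type promotion, sketched here with illustrative shapes:

import torch

x = torch.randn(4, 8, device="cuda", dtype=torch.half)
z = torch.randn(4, 8, device="cuda", dtype=torch.float)
assert (x + x).dtype == torch.half   # fp16 op fp16 stays fp16
assert (x + z).dtype == torch.float  # mixed fp16/fp32 promotes to fp32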

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_const(self):
def t(x, y):
@@ -94,10 +91,11 @@ def t(x, y):
jit_o = t_jit(x, y)
o = t(x, y)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y)))
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_chunk(self):
def t(x, y, z, q):
@@ -117,13 +115,14 @@ def t(x, y, z, q):
jit_o = t_jit(x, y, z, q)
o = t(x, y, z, q)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z, q)))
self.assertGraphContains(t_jit.graph_for(x, y, z, q), FUSION_GROUP)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_scalar_input(self):
def t(x : torch.Tensor, y : torch.Tensor, z : float):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
@@ -135,13 +134,14 @@ def t(x : torch.Tensor, y : torch.Tensor, z : float):
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_broadcasting(self):
def t(x : torch.Tensor, y : torch.Tensor, z : float):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
@@ -152,14 +152,15 @@ def t(x : torch.Tensor, y : torch.Tensor, z : float):
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)

@unittest.skipIf(True, "real broadcast with different output not supported yet")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_broadcasting_multiple_output_shape(self):
def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = x + 12
o1 = o + y
o2 = o + z
@@ -174,13 +175,14 @@ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = t(x, y, z)
self.assertEqual(o, jit_o)
# Currently cannot fuse this
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_broadcasting_multiple_output(self):
def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
o = x + 12
o1 = o + y
o2 = o + z
@@ -195,10 +197,10 @@ def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = t(x, y, z)
self.assertEqual(o, jit_o)
# Currently cannot fuse this
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP)

def _binary_test_helper(self, operation):
def t(x : torch.Tensor, y: torch.Tensor, z : float):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + z
o = operation(o, y)
return o
@@ -209,10 +211,10 @@ def t(x : torch.Tensor, y: torch.Tensor, z : float):
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)

def _unary_test_helper(self, operation):
def t(x : torch.Tensor, z : float):
def t(x: torch.Tensor, z: float):
o = x + z
o = operation(o)
return o
@@ -222,10 +224,11 @@ def t(x : torch.Tensor, z : float):
jit_o = t_jit(x, 2.0)
o = t(x, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, 2.0)))
self.assertGraphContains(t_jit.graph_for(x, 2.0), FUSION_GROUP)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_unary_ops(self):
operations = [torch.neg,
@@ -262,7 +265,8 @@ def test_unary_ops(self):
self._unary_test_helper(op)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_binary_ops(self):
operations = [torch.div,
@@ -283,62 +287,63 @@ def test_binary_ops(self):
self._binary_test_helper(op)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
# legacy fuser does not work for rand_like, see issue #34361
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_ternary_ops(self):
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
z = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
cond = torch.randint(0, 2, (4, 8, 32, 32)).to(dtype=torch.bool, device="cuda")

def add(x : torch.Tensor, other : torch.Tensor, alpha : float):
def add(x: torch.Tensor, other: torch.Tensor, alpha: float):
o = torch.relu(x)
o = torch.add(o, other=other, alpha=alpha)
return o
add_jit = torch.jit.script(add)
self._run_helper(add_jit, add, True, x, y, 2.0)
self._run_helper(add_jit, add, x, y, 2.0)

def clamp0(x : torch.Tensor, f : float):
def clamp0(x: torch.Tensor, f: float):
o = torch.rand_like(x)
o = o * torch.clamp(x, min=f)
return o
clamp0_jit = torch.jit.script(clamp0)
self._run_helper(clamp0_jit, clamp0, True, x, 0.5)
self._run_helper(clamp0_jit, clamp0, x, 0.5)

def clamp1(x : torch.Tensor, f : float, ff : float):
def clamp1(x: torch.Tensor, f: float, ff: float):
o = torch.rand_like(x)
o = o * torch.clamp(x, min=f, max=ff)
return o
clamp1_jit = torch.jit.script(clamp1)
self._run_helper(clamp1_jit, clamp1, True, x, -0.2, 0.7)
self._run_helper(clamp1_jit, clamp1, x, -0.2, 0.7)

def threshold(x : torch.Tensor, th : float, val : float):
def threshold(x: torch.Tensor, th: float, val: float):
o = torch.rand_like(x)
o = x * torch.threshold(o, th, val)
return o
threshold_jit = torch.jit.script(threshold)
self._run_helper(threshold_jit, threshold, True, x, 0.2, 0.9)
self._run_helper(threshold_jit, threshold, x, 0.2, 0.9)

def where(x : torch.Tensor, y : torch.Tensor, cond : torch.Tensor):
def where(x: torch.Tensor, y: torch.Tensor, cond: torch.Tensor):
o = torch.rand_like(x)
o = o * torch.where(cond, x, y)
return o
where_jit = torch.jit.script(where)
self._run_helper(where_jit, where, True, x, y, cond)
self._run_helper(where_jit, where, x, y, cond)

def lerp(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = torch.rand_like(x)
o = o * torch.lerp(x, y, z)
return o
lerp_jit = torch.jit.script(lerp)
self._run_helper(lerp_jit, lerp, True, x, y, z)
self._run_helper(lerp_jit, lerp, x, y, z)

def lerp_scale(x : torch.Tensor, y : torch.Tensor, z: float):
o = torch.rand_like(x)
o = o * torch.lerp(x, y, z)
return o
lerp_scale_jit = torch.jit.script(lerp_scale)
self._run_helper(lerp_scale_jit, lerp_scale, True, x, y, 0.5)
self._run_helper(lerp_scale_jit, lerp_scale, x, y, 0.5)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
@@ -353,27 +358,28 @@ def addcmul(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor, value : float)
o = torch.addcmul(o, y, z, value=value)
return o
addcmul_jit = torch.jit.script(addcmul)
self._run_helper(addcmul_jit, addcmul, True, x, y, z, 2.0)
self._run_helper(addcmul_jit, addcmul, x, y, z, 2.0)

def addcmul_no_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z)
return o
addcmul_no_alpha_jit = torch.jit.script(addcmul_no_alpha)
self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, True, x, y, z)
self._run_helper(addcmul_no_alpha_jit, addcmul_no_alpha, x, y, z)

def addcmul_const_alpha(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
o = torch.add(x, 0.5)
o = torch.addcmul(o, y, z, value=0.75)
return o
addcmul_const_alpha_jit = torch.jit.script(addcmul_const_alpha)
self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, True, x, y, z)
self._run_helper(addcmul_const_alpha_jit, addcmul_const_alpha, x, y, z)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires profiling node to run cuda fuser")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_dynamic_size(self):
def t(x : torch.Tensor, y : torch.Tensor, z : float):
def t(x: torch.Tensor, y: torch.Tensor, z: float):
o = x + y
o = o + z
return o
@@ -384,7 +390,7 @@ def t(x : torch.Tensor, y : torch.Tensor, z : float):
jit_o = t_jit(x, y, 2.0)
o = t(x, y, 2.0)
self.assertEqual(o, jit_o)
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP)
x = torch.randn(8, 32, 16, 8, dtype=torch.float, device="cuda")
y = torch.randn(16, 8, dtype=torch.float, device="cuda")
jit_o = t_jit(x, y, 2.0)
@@ -400,5 +406,55 @@ def test_random_topo(self):
os.environ["PYTORCH_CUDA_FUSER_DISABLE_FALLBACK"] = "1"
self.assertTrue(runDefaultTestWithSeed(28449))


class TestPassManagerCudaFuser(JitTestCase):

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR !=
ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective")
@skipIfRocm
def test_context_manager_test(self):
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
with torch.jit.fuser('fuser2'):
with torch.jit.fuser('fuser2'):

def t1(x, y):
o = x + y
o = o + 2.0
return o
t_jit = torch.jit.script(t1)
t_jit(x, y)
t_jit(x, y)
self.assertGraphContains(t_jit.graph_for(x, y), FUSION_GROUP)

def t2(x, y):
o = x + y
o = o + 3.0
return o
t_jit_2 = torch.jit.script(t2)
t_jit_2(x, y)
t_jit_2(x, y)
self.assertGraphContains(t_jit_2.graph_for(x, y), FUSION_GROUP)

def t3(x, y):
o = x + y
o = o + 4.0
return o
t_jit_3 = torch.jit.script(t3)
t_jit_3(x, y)
t_jit_3(x, y)
self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), FUSION_GROUP, 0)
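The three scripted functions pin down the context-manager semantics: t1 and t2, compiled inside the (nested) fuser blocks, must fuse, while t3, compiled after exit, must not. A compact restatement, assuming 'fuser2' selects the CUDA fuser and the prior selection is restored on exit:

with torch.jit.fuser('fuser2'):      # CUDA fuser selected
    with torch.jit.fuser('fuser2'):  # nested re-entry: still selected
        pass                         # t1 compiled here -> FUSION_GROUP expected
    pass                             # t2 compiled here -> FUSION_GROUP expected
# prior fuser restored: t3 compiled here -> exactly 0 FUSION_GROUP nodes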

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
@skipIfRocm
def test_register_fuser(self):
self.assertFalse(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(True))
self.assertTrue(torch._C._jit_nvfuser_enabled())
self.assertTrue(torch._C._jit_set_nvfuser_enabled(False))
self.assertFalse(torch._C._jit_nvfuser_enabled())
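Restated, the contract these assertions encode is that the setter returns the previous enabled state:

import torch

was_on = torch._C._jit_set_nvfuser_enabled(True)
assert not was_on                                # started out disabled
assert torch._C._jit_nvfuser_enabled()           # now reads back enabled
assert torch._C._jit_set_nvfuser_enabled(False)  # returns True: it was on
assert not torch._C._jit_nvfuser_enabled()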

if __name__ == '__main__':
run_tests()