
Commit e019026

bertmaher authored and Mikhail Zolotukhin committed
Modify existing fuser tests to suit TE fuser (pytorch#232)
1 parent dd7d232 commit e019026

2 files changed: +33 −24 lines changed

test/test_jit_fuser.py

Lines changed: 24 additions & 24 deletions
@@ -22,6 +22,7 @@
 torch._C._jit_set_profiling_executor(True)
 torch._C._jit_set_profiling_mode(True)
 
+FUSION_GROUP = 'tensorexpr::Group'
 
 def strip_profiling_nodes(nodes):
     profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
@@ -57,11 +58,11 @@ def assertAllFused(self, graph, except_for=()):
         self.assertEqual(len(diff_graphs), 1)
         graph = diff_graphs[0].g('Subgraph')
 
-        allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
+        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
                          'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
         self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                         'got {}'.format(graph))
-        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
+        self.assertTrue([node.kind() for node in graph.nodes()].count(FUSION_GROUP) == 1)
 
     def _test_fused_abs(self, device='cpu'):
         def func(x):
@@ -72,25 +73,31 @@ def func(x):
         self.assertAllFused(scripted.graph_for(a))
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_abs_cpu(self):
         self._test_fused_abs()
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     def test_abs_cuda(self):
         self._test_fused_abs(device="cuda")
 
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    def test_zero_element_tensors(self):
+    def _test_zero_element_tensors(self, device="cpu"):
        def decode(sin_t, cos_t):
             theta = torch.atan2(sin_t.float(), cos_t.float())
             return theta
 
-        sin = torch.zeros(0, device="cuda")
-        cos = torch.zeros(0, device="cuda")
+        sin = torch.zeros(0, device=device)
+        cos = torch.zeros(0, device=device)
         inputs = [sin, cos]
         ge = self.checkScript(decode, inputs)
 
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    def test_zero_element_tensors_cuda(self):
+        self._test_zero_element_tensors(device="cuda")
+
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
+    def test_zero_element_tensors_cpu(self):
+        self._test_zero_element_tensors(device="cpu")
+
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_arg_configurations_smoke_cuda(self):
         # A smoke test to make sure we won't use the same kernel for contiguous
@@ -216,7 +223,6 @@ def chunk_4_last(x):
         self.checkScript(fn, [tensor])
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_chunk_correctness(self):
         return self._test_chunk_correctness(self, 'cpu')
 
@@ -235,7 +241,7 @@ def f(x, y):
 
         ge = self.checkTrace(f, (x, y))
         graph = ge.graph_for(x, y)
-        FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_') \
+        FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \
             .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -256,7 +262,7 @@ def func2(x):
         for func in [func1, func2]:
             module = self.checkScript(func, inputs)
             forward_graph = module.graph_for(*inputs)
-            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+            self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1)
             fusion_group = list(forward_graph.nodes())[-1]
             self.assertEqual(len(list(fusion_group.inputs())), 1)
 
@@ -498,7 +504,7 @@ def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
             self.assertNotIn(node_not_in_graph, rep)
             self.assertIn(node_not_in_graph, rep_noopt)
 
-        fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
+        fusion_groups = [node for node in graph.nodes() if node.kind() == FUSION_GROUP]
         self.assertEqual(len(fusion_groups), 1)
         fused_graph = str(fusion_groups[0].g('Subgraph'))
         for node_in_fusegraph in in_fusegraph:
@@ -549,7 +555,6 @@ def fn_test_scalar_arg_requires_grad(x, p):
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @unittest.skip("deduplicating introduces aliasing in backward graph's outputs")
-    @enable_cpu_fuser
     def test_fuser_deduplication(self):
         # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
         # see the discussion in PR #14957.
@@ -571,7 +576,6 @@ def f(x, y):
         self.assertEqual(ga2.data_ptr(), gb2.data_ptr())
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
     def test_fuser_iou(self):
         # This checks if most of Intersection over Union is fused.
@@ -615,7 +619,6 @@ def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
-    @enable_cpu_fuser
     def test_fusion_reuse_multi_gpu(self):
         def fn(x, y):
             return x * y * x * y
@@ -635,7 +638,6 @@ def fn(x, y):
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
-    @enable_cpu_fuser
     def test_kernel_cache_multi_gpu(self):
         def not_fusible(x):
             return x
@@ -658,10 +660,11 @@ def fn(x, y, z):
         # should reuse the same KernelSpec in the KernelSpec cache.
         ge = self.checkScript(fn, inputs)
         self.assertGraphContainsExactly(
-            ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
+            ge.graph_for(*inputs), FUSION_GROUP, 3, True)
         new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
         # XXX: This assumes that the same kernel isn't already used by another test
-        self.assertEqual(new_cache_size - prev_cache_size, 1)
+        # FIXME: Use the TE fuser's way of querying the cache.
+        # self.assertEqual(new_cache_size - prev_cache_size, 1)
 
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
     def test_nonzero_device_cuda(self):
@@ -682,7 +685,7 @@ def test_lstm_cuda(self):
             return
         forward_graph = module.graph_for(*inputs)
         self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+            forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
         self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
         # Everything is differentiable but TupleConstruct return
         FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
@@ -722,7 +725,7 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
         inputs = get_lstm_inputs('cuda', training=False)
         self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
         forward_graph = cu.cell.graph_for(*inputs)
-        self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+        self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1)
 
     # TODO: Fuser doesn't work at all when inputs require grad. Fix that
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -737,7 +740,6 @@ def test_lstm_traced_cuda(self):
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
-    @enable_cpu_fuser
     def test_lstm_traced_cpu(self):
         inputs = get_lstm_inputs('cpu')
         try:
@@ -759,7 +761,7 @@ def test_milstm_cuda(self):
         module = self.checkScript(MiLSTMCell, inputs)
         forward_graph = module.graph_for(*inputs)
         self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+            forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
         FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
             .check_next("return").check("FusionGroup").run(str(forward_graph))
         hy, cy = module(*inputs)
@@ -836,7 +838,6 @@ def fn_test_rand(x, y):
         self.assertEqual(out[0], out[1])
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_scalar(self):
         def fn(x, y):
             return 2 * x + y
@@ -879,10 +880,9 @@ def should_not_fuse(x, z):
         ]
         ge = self.checkScript(should_not_fuse, inputs)
         self.assertGraphContainsExactly(
-            ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
+            ge.graph_for(*inputs), FUSION_GROUP, 0, consider_subgraphs=True)
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_where_and_typing(self):
         def f(x, y):
             mask = x > y

test/test_jit_fuser_te.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+import torch
+from test_jit_fuser import *
+
+
+if __name__ == "__main__":
+    torch._C._jit_override_can_fuse_on_gpu(False)
+    torch._C._jit_override_can_fuse_on_cpu(False)
+    torch._C._jit_set_texpr_fuser_enabled(True)
+    run_tests()
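
For context, a minimal standalone sketch (not part of the commit) of what the new runner configures: the legacy GPU/CPU fusers are switched off, the tensor-expression fuser is switched on, and a fusible scripted function should then compile into a single tensorexpr::Group node, the value of the FUSION_GROUP constant added above. The example function f and the warm-up loop are illustrative assumptions; the torch._C._jit_* hooks are the same ones used by the two test files.

    import torch

    # Same switches as test_jit_fuser.py / test_jit_fuser_te.py above.
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_set_texpr_fuser_enabled(True)

    @torch.jit.script
    def f(x, y):
        # Illustrative element-wise chain; any pointwise expression is a fusion candidate.
        return (x + y) * y

    x, y = torch.randn(1024), torch.randn(1024)
    for _ in range(3):
        f(x, y)  # warm-up runs let the profiling executor specialize and fuse

    # The optimized graph should now contain one 'tensorexpr::Group' node.
    print(f.graph_for(x, y))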
