 torch._C._jit_set_profiling_executor(True)
 torch._C._jit_set_profiling_mode(True)
 
+FUSION_GROUP = 'tensorexpr::Group'
 
 
 def strip_profiling_nodes(nodes):
     profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
@@ -57,11 +58,11 @@ def assertAllFused(self, graph, except_for=()):
         self.assertEqual(len(diff_graphs), 1)
         graph = diff_graphs[0].g('Subgraph')
-        allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
+        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
                          'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
         self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                         'got {}'.format(graph))
-        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
+        self.assertTrue([node.kind() for node in graph.nodes()].count(FUSION_GROUP) == 1)
 
     def _test_fused_abs(self, device='cpu'):
         def func(x):
@@ -72,25 +73,31 @@ def func(x):
         self.assertAllFused(scripted.graph_for(a))
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_abs_cpu(self):
         self._test_fused_abs()
 
     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
     def test_abs_cuda(self):
         self._test_fused_abs(device="cuda")
 
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    def test_zero_element_tensors(self):
+    def _test_zero_element_tensors(self, device="cpu"):
         def decode(sin_t, cos_t):
             theta = torch.atan2(sin_t.float(), cos_t.float())
             return theta
 
-        sin = torch.zeros(0, device="cuda")
-        cos = torch.zeros(0, device="cuda")
+        sin = torch.zeros(0, device=device)
+        cos = torch.zeros(0, device=device)
         inputs = [sin, cos]
         ge = self.checkScript(decode, inputs)
 
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    def test_zero_element_tensors_cuda(self):
+        self._test_zero_element_tensors(device="cuda")
+
+    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
+    def test_zero_element_tensors_cpu(self):
+        self._test_zero_element_tensors(device="cpu")
+
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_arg_configurations_smoke_cuda(self):
         # A smoke test to make sure we won't use the same kernel for contiguous
@@ -216,7 +223,6 @@ def chunk_4_last(x):
         self.checkScript(fn, [tensor])
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_chunk_correctness(self):
         return self._test_chunk_correctness(self, 'cpu')
 
@@ -235,7 +241,7 @@ def f(x, y):
         ge = self.checkTrace(f, (x, y))
         graph = ge.graph_for(x, y)
-        FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_ ') \
+        FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_ ') \
             .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -256,7 +262,7 @@ def func2(x):
         for func in [func1, func2]:
             module = self.checkScript(func, inputs)
             forward_graph = module.graph_for(*inputs)
-            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+            self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1)
             fusion_group = list(forward_graph.nodes())[-1]
             self.assertEqual(len(list(fusion_group.inputs())), 1)
@@ -498,7 +504,7 @@ def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
                 self.assertNotIn(node_not_in_graph, rep)
                 self.assertIn(node_not_in_graph, rep_noopt)
 
-            fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
+            fusion_groups = [node for node in graph.nodes() if node.kind() == FUSION_GROUP]
             self.assertEqual(len(fusion_groups), 1)
             fused_graph = str(fusion_groups[0].g('Subgraph'))
             for node_in_fusegraph in in_fusegraph:
@@ -549,7 +555,6 @@ def fn_test_scalar_arg_requires_grad(x, p):
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @unittest.skip("deduplicating introduces aliasing in backward graph's outputs")
-    @enable_cpu_fuser
     def test_fuser_deduplication(self):
         # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
         # see the discussion in PR #14957.
@@ -571,7 +576,6 @@ def f(x, y):
         self.assertEqual(ga2.data_ptr(), gb2.data_ptr())
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
     def test_fuser_iou(self):
         # This checks if most of Intersection over Union is fused.
@@ -615,7 +619,6 @@ def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
-    @enable_cpu_fuser
     def test_fusion_reuse_multi_gpu(self):
         def fn(x, y):
             return x * y * x * y
@@ -635,7 +638,6 @@ def fn(x, y):
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
-    @enable_cpu_fuser
     def test_kernel_cache_multi_gpu(self):
         def not_fusible(x):
             return x
@@ -658,10 +660,11 @@ def fn(x, y, z):
         # should reuse the same KernelSpec in the KernelSpec cache.
         ge = self.checkScript(fn, inputs)
         self.assertGraphContainsExactly(
-            ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
+            ge.graph_for(*inputs), FUSION_GROUP, 3, True)
         new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
         # XXX: This assumes that the same kernel isn't already used by another test
-        self.assertEqual(new_cache_size - prev_cache_size, 1)
+        # FIXME: Use the TE fuser's way of querying the cache.
+        # self.assertEqual(new_cache_size - prev_cache_size, 1)
 
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
     def test_nonzero_device_cuda(self):
@@ -682,7 +685,7 @@ def test_lstm_cuda(self):
             return
         forward_graph = module.graph_for(*inputs)
         self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+            forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
         self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
         # Everything is differentiable but TupleConstruct return
         FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
@@ -722,7 +725,7 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
         inputs = get_lstm_inputs('cuda', training=False)
         self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
         forward_graph = cu.cell.graph_for(*inputs)
-        self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+        self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1)
 
     # TODO: Fuser doesn't work at all when inputs require grad. Fix that
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -737,7 +740,6 @@ def test_lstm_traced_cuda(self):
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
     @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
-    @enable_cpu_fuser
     def test_lstm_traced_cpu(self):
         inputs = get_lstm_inputs('cpu')
         try:
@@ -759,7 +761,7 @@ def test_milstm_cuda(self):
         module = self.checkScript(MiLSTMCell, inputs)
         forward_graph = module.graph_for(*inputs)
         self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+            forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
         FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
             .check_next("return").check("FusionGroup").run(str(forward_graph))
         hy, cy = module(*inputs)
@@ -836,7 +838,6 @@ def fn_test_rand(x, y):
         self.assertEqual(out[0], out[1])
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_scalar(self):
         def fn(x, y):
             return 2 * x + y
@@ -879,10 +880,9 @@ def should_not_fuse(x, z):
         ]
         ge = self.checkScript(should_not_fuse, inputs)
         self.assertGraphContainsExactly(
-            ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
+            ge.graph_for(*inputs), FUSION_GROUP, 0, consider_subgraphs=True)
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @enable_cpu_fuser
     def test_where_and_typing(self):
         def f(x, y):
             mask = x > y
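For reference, a minimal standalone sketch (not part of the diff) of how the new FUSION_GROUP node kind can be looked up in a TorchScript graph. The helper count_fusion_groups and the toy function fn are hypothetical names used only for illustration; whether a fusion group actually appears depends on the TensorExpr fuser being enabled for the target device, so the count may be 0 on a default CPU build.

import torch

torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)

FUSION_GROUP = 'tensorexpr::Group'  # node kind emitted by the TensorExpr fuser


def count_fusion_groups(graph):
    # Count top-level fusion-group nodes in a TorchScript graph.
    return sum(1 for node in graph.nodes() if node.kind() == FUSION_GROUP)


@torch.jit.script
def fn(x, y):
    return x * y * x * y


x = torch.randn(4, 4)
y = torch.randn(4, 4)
fn(x, y)
fn(x, y)  # run twice so the profiling executor can specialize and fuse
print(count_fusion_groups(fn.graph_for(x, y)))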