    unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
)

-torch._inductor.config.triton.autotune = False  # too slow
+torch._inductor.config.triton.autotune_pointwise = False  # too slow


# For OneDNN bf16 path, OneDNN requires the cpu has intel avx512 with avx512bw,
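The module-level flag changed above was renamed from triton.autotune to triton.autotune_pointwise. A minimal sketch of opting a single test back in to pointwise autotuning under the new name (the test body, shapes, and tolerance here are illustrative assumptions, not code from this file):

    from unittest.mock import patch
    import torch
    from torch._inductor import config

    @patch.object(config.triton, "autotune_pointwise", True)
    def test_pointwise_autotune_smoke(self):
        @torch._dynamo.optimize("inductor")
        def add_relu(x, y):
            return torch.relu(x + y)

        x = torch.randn(64, 64, device=self.device)
        y = torch.randn(64, 64, device=self.device)
        # compiled result should match eager
        self.assertTrue(torch.allclose(add_relu(x, y), torch.relu(x + y)))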
@@ -2505,76 +2505,6 @@ def fn(x, y):
        self.assertEqual(a.stride(), c.stride())
        self.assertEqual(c.stride()[2], 1)

-    @requires_cuda()
-    @patch.object(config.triton, "convolution", "triton")
-    @patch.object(config.triton, "dense_indexing", "True")
-    def test_triton_conv(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def triton_conv(
-            x,
-            w,
-            bias,
-            stride,
-            padding,
-            dilation,
-            groups,
-        ):
-            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-            return y
-
-        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
-        dtype = torch.float32
-        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
-        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
-        bias = torch.randn((32), dtype=dtype, device=self.device)
-
-        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
-        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
-
-    @requires_cuda()
-    @patch.object(config.triton, "convolution", "autotune")
-    @patch.object(config.triton, "dense_indexing", "True")
-    def test_conv_autotune(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def triton_conv(
-            x,
-            w,
-            bias,
-            stride,
-            padding,
-            dilation,
-            groups,
-        ):
-            y = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-            return y
-
-        stride, padding, dilation, groups = (1, 1), (0, 0), (1, 1), 1
-        dtype = torch.float32
-        x = torch.randn((32, 128, 32, 32), dtype=dtype, device=self.device)
-        w = torch.randn((32, 128, 1, 1), dtype=dtype, device=self.device)
-        bias = torch.randn((32), dtype=dtype, device=self.device)
-
-        y = triton_conv(x, w, bias, stride, padding, dilation, groups)
-        y_correct = torch.conv2d(x, w, bias, stride, padding, dilation, groups)
-        self.assertTrue(same(y, y_correct, cos_similarity=True, tol=0.1))
-
-    @patch.object(config.triton, "mm", "triton")
-    def test_triton_mm2(self):
-        @torch._dynamo.optimize("inductor", nopython=True)
-        def fn(x, y):
-            return torch.relu(torch.mm(x, y))
-
-        N = 1024
-        a = torch.randn([N, N], device=self.device, dtype=torch.float32)
-        b = torch.randn([N, N], device=self.device, dtype=torch.float32)
-        c1 = torch.relu(torch.mm(a, b))
-        torch._inductor.metrics.reset()
-        c = fn(a, b)
-        assert torch.allclose(c1, c, atol=1e-3, rtol=1e-3)
-        if self.device == "cuda":
-            assert torch._inductor.metrics.generated_kernel_count == 1
-
    def test_std(self):
        def fn(x):
            return (
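The deleted tests above exercised the old config.triton.mm / config.triton.convolution template paths. For reference, the numeric check that test_triton_mm2 performed can still be expressed without those knobs; a minimal standalone sketch (the device string and tolerances are assumptions carried over from the deleted test):

    import torch

    @torch._dynamo.optimize("inductor", nopython=True)
    def mm_relu(x, y):
        return torch.relu(torch.mm(x, y))

    N = 1024
    a = torch.randn([N, N], device="cuda", dtype=torch.float32)
    b = torch.randn([N, N], device="cuda", dtype=torch.float32)
    # compiled result vs. eager reference
    assert torch.allclose(torch.relu(torch.mm(a, b)), mm_relu(a, b), atol=1e-3, rtol=1e-3)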
@@ -4560,12 +4490,6 @@ def fn(a, b):
        )
        expected_kernel = 0
        # codegen mm kernel from template
-        if config.triton.mm != "aten" and self.device == "cuda":
-            expected_kernel = 1
-            if config.triton.mm == "autotune":
-                self.assertLessEqual(
-                    torch._inductor.metrics.generated_kernel_count, expected_kernel
-                )
        self.assertEqual(
            torch._inductor.metrics.generated_kernel_count, expected_kernel
        )
@@ -4641,15 +4565,6 @@ def run(x):
        result.sum().backward()

        expected_kernel = 4
-        if config.triton.mm != "aten" and self.device == "cuda":
-            # fwd: 2 * (mm+dropout) kernels = 2 kernels
-            # bwd: dropout + (mm) + 2 * (mm+dropout) kernels = 4 kernels
-            # expect 2 + 4 = 6 kernels
-            expected_kernel = 6
-            if config.triton.mm == "autotune":
-                self.assertLessEqual(
-                    torch._inductor.metrics.generated_kernel_count, expected_kernel
-                )
        self.assertEqual(
            torch._inductor.metrics.generated_kernel_count, expected_kernel
        )
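Both assertions above rely on the inductor kernel-count metric. A minimal sketch of that counting pattern, where the expected count of 1 is an assumption (a single fused pointwise kernel on cuda), not something this diff asserts:

    import torch
    import torch._inductor.metrics as metrics

    @torch._dynamo.optimize("inductor")
    def fused(x):
        return torch.relu(x) + 1  # pointwise chain inductor should fuse

    metrics.reset()
    fused(torch.randn(128, device="cuda"))
    assert metrics.generated_kernel_count == 1  # assumed: one fused kernel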
@@ -4979,7 +4894,6 @@ def fn(x, y):
        inputs = (inputs[1], inputs[0])
        self.assertTrue(same(opt(*inputs), fn(*inputs)))

-    @patch.object(config.triton, "mm", "aten")
    def test_list_clearing(self):

        if self.device == "cpu":
@@ -5685,7 +5599,7 @@ def forward(self, view, reshape_2):
        res = opt_mod(*args)
        self.assertTrue(same(ref, res))

-    @patch.object(config.triton, "autotune", True)
+    @patch.object(config.triton, "autotune_pointwise", True)
    def test_inplace_add_alpha_autotune(self):
        def fn(x, y):
            aten.add_.Tensor(x, y, alpha=0.55)
@@ -5703,7 +5617,7 @@ def fn(x, y):
        fn_compiled([x3, y])
        assert same(x2, x3)

-    @patch.object(config.triton, "autotune", True)
+    @patch.object(config.triton, "autotune_pointwise", True)
    def test_inplace_buffer_autotune(self):
        def foo(x, y, z):
            a = x @ y