85
85
skipIfWindows ,
86
86
skipIfXpu ,
87
87
subtest ,
88
+ skipIfRocmArch ,
88
89
TEST_WITH_ASAN ,
89
90
TEST_WITH_ROCM ,
90
91
)
119
120
120
121
121
122
HAS_AVX2 = "fbgemm" in torch .backends .quantized .supported_engines
123
+ _desired_test_bases = get_desired_device_type_test_bases ()
124
+ RUN_CPU = any (getattr (x , "device_type" , "" ) == "cpu" for x in _desired_test_bases )
125
+ RUN_GPU = any (getattr (x , "device_type" , "" ) == GPU_TYPE for x in _desired_test_bases )
126
+ NAVI_ARCH = ("gfx1100" , "gfx1101" ) # Used for navi exclusive skips on ROCm
122
127
123
128
aten = torch .ops .aten
124
129
@@ -1794,6 +1799,7 @@ def fn(x):
1794
1799
# make sure things also work if they aren't unrolled
1795
1800
self .common (fn , (torch .randn (8 , 3 ),))
1796
1801
1802
+ @skipIfRocmArch (NAVI_ARCH )
1797
1803
def test_multilayer_sum_low_prec (self ):
1798
1804
# fp16 nyi for cpu
1799
1805
if self .device == "cpu" :
@@ -1804,6 +1810,7 @@ def fn(a):
1804
1810
1805
1811
self .common (fn , ((torch .rand ((10 , 3 , 352 , 352 ), dtype = torch .float16 ),)))
1806
1812
1813
+ @skipIfRocmArch (NAVI_ARCH )
1807
1814
def test_multilayer_prime_size (self ):
1808
1815
def fn (a ):
1809
1816
return torch .max (a ), torch .sum (a )
@@ -1815,6 +1822,7 @@ def fn(a):
1815
1822
1816
1823
@skip_if_gpu_halide
1817
1824
@skipCPUIf (IS_MACOS , "fails on macos" )
1825
+ @skipIfRocmArch (NAVI_ARCH )
1818
1826
def test_multilayer_var (self ):
1819
1827
def fn (a ):
1820
1828
return torch .var (a )
@@ -2966,6 +2974,7 @@ def fn(a, b):
2966
2974
self .common (fn , (torch .randn (8 , 8 ), torch .randn (8 , 8 )))
2967
2975
2968
2976
@skip_if_halide # only 32-bit indexing
2977
+ @skipIfRocmArch (NAVI_ARCH )
2969
2978
def test_large_tensor_reduction (self ):
2970
2979
if not _has_sufficient_memory (self .device , 4.5 * 1024 ** 3 ): # 4.5 GiB
2971
2980
raise unittest .SkipTest ("insufficient memory" )
@@ -2987,6 +2996,7 @@ def fn(a):
2987
2996
self .assertEqual (actual , expect )
2988
2997
2989
2998
@skip_if_gpu_halide # only 32-bit indexing
2999
+ @skipIfRocmArch (NAVI_ARCH )
2990
3000
def test_large_broadcast_reduction (self ):
2991
3001
if self .device == "cpu" :
2992
3002
raise unittest .SkipTest ("Fails on CPU" )
@@ -4148,6 +4158,7 @@ def test_conv2d_channels_last(self):
4148
4158
check_lowp = False ,
4149
4159
)
4150
4160
4161
+ @skipIfRocmArch (NAVI_ARCH )
4151
4162
def test_conv2d_backward_channels_last (self ):
4152
4163
def fn (grad_output , inp , weight ):
4153
4164
convolution_backward_8 = torch .ops .aten .convolution_backward .default (
@@ -4932,6 +4943,7 @@ def fn(x, y):
4932
4943
self .assertEqual (c .stride ()[2 ], 1 )
4933
4944
4934
4945
@skip_if_gpu_halide
4946
+ @skipIfRocmArch (NAVI_ARCH )
4935
4947
def test_std (self ):
4936
4948
def fn (x ):
4937
4949
return (
@@ -4974,6 +4986,7 @@ def test_batch_norm_2d(self):
4974
4986
4975
4987
# From yolov3
4976
4988
@with_tf32_off
4989
+ @skipIfRocmArch (NAVI_ARCH )
4977
4990
def test_batch_norm_2d_2 (self ):
4978
4991
if self .device == "cpu" :
4979
4992
raise unittest .SkipTest (f"requires { GPU_TYPE } " )
@@ -5120,6 +5133,7 @@ def fn(dist, angle):
5120
5133
self .common (fn , (* inp ,))
5121
5134
5122
5135
@skip_if_gpu_halide # incorrect result on CUDA
5136
+ @skipIfRocmArch (NAVI_ARCH )
5123
5137
def test_cauchy (self ):
5124
5138
def fn (x , y ):
5125
5139
return torch .sum (1 / (torch .unsqueeze (x , - 1 ) - y ))
@@ -6520,6 +6534,7 @@ def fn(a):
6520
6534
y = fn_compiled (x )
6521
6535
self .assertTrue (y is not x )
6522
6536
6537
+ @skipIfRocmArch (NAVI_ARCH )
6523
6538
def test_l1_loss (self ):
6524
6539
def fn (a , b ):
6525
6540
return torch .nn .functional .l1_loss (a , b ), torch .nn .functional .mse_loss (a , b )
@@ -6920,6 +6935,7 @@ def fn(x):
6920
6935
fn , (torch .tensor ([1 , float ("inf" ), 2 , float ("-inf" ), float ("nan" )]),)
6921
6936
)
6922
6937
6938
+ @skipIfRocmArch (NAVI_ARCH )
6923
6939
def test_any (self ):
6924
6940
def fn (x ):
6925
6941
return (
@@ -7686,6 +7702,8 @@ def fn(a, dim, index, b, reduce):
7686
7702
)
7687
7703
7688
7704
@skip_if_gpu_halide
7705
+ # issue #1150
7706
+ @skipIfRocmArch (NAVI_ARCH )
7689
7707
def test_dense_mask_index (self ):
7690
7708
r"""
7691
7709
There will be a little difference for reduce order between aten and inductor
@@ -8693,6 +8711,7 @@ def fn(a, b):
8693
8711
b = torch .rand (2 , 2 , 1 , 4 , 1 ).int ()
8694
8712
self .common (fn , (a , b ))
8695
8713
8714
+ @skipIfRocmArch (NAVI_ARCH )
8696
8715
def test_argmax_argmin1 (self ):
8697
8716
def fn (x ):
8698
8717
return (aten .argmax (x ), aten .argmin (x ))
@@ -8704,6 +8723,7 @@ def fn(x):
8704
8723
],
8705
8724
)
8706
8725
8726
+ @skipIfRocmArch (NAVI_ARCH )
8707
8727
def test_argmax_argmin2 (self ):
8708
8728
def fn (x ):
8709
8729
return (
@@ -8715,6 +8735,7 @@ def fn(x):
8715
8735
8716
8736
self .common (fn , (torch .randn ([144 , 144 ]),))
8717
8737
8738
+ @skipIfRocmArch (NAVI_ARCH )
8718
8739
def test_argmax_argmin_with_duplicates (self ):
8719
8740
def fn (x ):
8720
8741
return (
@@ -8737,6 +8758,7 @@ def fn(x):
8737
8758
self .common (fn , (t1 ,))
8738
8759
8739
8760
@skip_if_halide # nan behavior
8761
+ @skipIfRocmArch (NAVI_ARCH )
8740
8762
def test_argmax_argmin_with_nan (self ):
8741
8763
def fn (x ):
8742
8764
return (
@@ -8860,6 +8882,7 @@ def fn(x):
8860
8882
],
8861
8883
)
8862
8884
8885
+ @skipIfRocmArch (NAVI_ARCH )
8863
8886
def test_tmp_not_defined_issue1 (self ):
8864
8887
def forward (
8865
8888
primals_3 ,
@@ -9259,6 +9282,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
9259
9282
else :
9260
9283
self .assertEqual (len (inps ), 0 )
9261
9284
9285
+ @skipIfRocmArch (NAVI_ARCH )
9262
9286
def test_dtype_mismatch_issue (self ):
9263
9287
def fn (x ):
9264
9288
attn = torch .nn .functional .pad (x , [0 , 1 ])
@@ -12349,6 +12373,7 @@ def test_rnn_compile_safe(self):
12349
12373
12350
12374
class NanCheckerTest (TestCase ):
12351
12375
@config .patch ("nan_asserts" , True )
12376
+ @skipIfRocmArch (NAVI_ARCH )
12352
12377
def test_nan_checker_pass (self ):
12353
12378
def f (x ):
12354
12379
return torch .softmax (x , dim = - 1 )
0 commit comments