Commit 8a4d1e2

jataylo authored and dnikolaev-amd committed
Add skipIfRocmArch decorator for Navi skips (#1356)
1 parent d98149c commit 8a4d1e2

File tree: 4 files changed, +64 -1 lines changed

test/inductor/test_cuda_repro.py

Lines changed: 5 additions & 1 deletion
@@ -31,6 +31,7 @@
     freeze_rng_state,
     IS_FBCODE,
     skipIfRocm,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
 )
 from torch.testing._internal.inductor_utils import skipCUDAIf
@@ -52,7 +53,7 @@
         sys.exit(0)
     raise
 
-
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm
 TestCase = test_torchinductor.TestCase
 ToTuple = test_torchinductor.ToTuple
 check_model_cuda = test_torchinductor.check_model_cuda
@@ -336,6 +337,7 @@ def foo(x):
         out_ref.add_(2)
         # self.assertEqual(out_ref, out)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_accuracy_issue1(self):
         class Repro(torch.nn.Module):
             def __init__(self) -> None:
@@ -372,6 +374,7 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
         assert same_two_models(mod, opt_mod, args), "Dynamo failed"
 
     @config.patch(allow_buffer_reuse=False)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103461(self):
         def forward(add_1):
             var_mean = torch.ops.aten.var_mean.correction(
@@ -870,6 +873,7 @@ def forward(self, x):
         res2 = jit_func(x)
         self.assertEqual(res1, res2)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103481(self):
         def fn(x, y):
             # NOTE: 6 dimensions is important! does not fail for 5 dimensions

test/inductor/test_torchinductor.py

Lines changed: 25 additions & 0 deletions
@@ -85,6 +85,7 @@
     skipIfWindows,
     skipIfXpu,
     subtest,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
 )
@@ -119,6 +120,10 @@
 
 
 HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
+_desired_test_bases = get_desired_device_type_test_bases()
+RUN_CPU = any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
+RUN_GPU = any(getattr(x, "device_type", "") == GPU_TYPE for x in _desired_test_bases)
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm
 
 aten = torch.ops.aten
 
@@ -1794,6 +1799,7 @@ def fn(x):
         # make sure things also work if they aren't unrolled
         self.common(fn, (torch.randn(8, 3),))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_sum_low_prec(self):
         # fp16 nyi for cpu
         if self.device == "cpu":
@@ -1804,6 +1810,7 @@ def fn(a):
 
         self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_prime_size(self):
         def fn(a):
             return torch.max(a), torch.sum(a)
@@ -1815,6 +1822,7 @@ def fn(a):
 
     @skip_if_gpu_halide
     @skipCPUIf(IS_MACOS, "fails on macos")
+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_var(self):
         def fn(a):
             return torch.var(a)
@@ -2966,6 +2974,7 @@ def fn(a, b):
         self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
 
     @skip_if_halide  # only 32-bit indexing
+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_tensor_reduction(self):
         if not _has_sufficient_memory(self.device, 4.5 * 1024**3):  # 4.5 GiB
             raise unittest.SkipTest("insufficient memory")
@@ -2987,6 +2996,7 @@ def fn(a):
         self.assertEqual(actual, expect)
 
     @skip_if_gpu_halide  # only 32-bit indexing
+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_broadcast_reduction(self):
         if self.device == "cpu":
             raise unittest.SkipTest("Fails on CPU")
@@ -4148,6 +4158,7 @@ def test_conv2d_channels_last(self):
             check_lowp=False,
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_conv2d_backward_channels_last(self):
         def fn(grad_output, inp, weight):
             convolution_backward_8 = torch.ops.aten.convolution_backward.default(
@@ -4932,6 +4943,7 @@ def fn(x, y):
         self.assertEqual(c.stride()[2], 1)
 
     @skip_if_gpu_halide
+    @skipIfRocmArch(NAVI_ARCH)
     def test_std(self):
         def fn(x):
             return (
@@ -4974,6 +4986,7 @@ def test_batch_norm_2d(self):
 
     # From yolov3
     @with_tf32_off
+    @skipIfRocmArch(NAVI_ARCH)
     def test_batch_norm_2d_2(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
@@ -5120,6 +5133,7 @@ def fn(dist, angle):
         self.common(fn, (*inp,))
 
     @skip_if_gpu_halide  # incorrect result on CUDA
+    @skipIfRocmArch(NAVI_ARCH)
     def test_cauchy(self):
         def fn(x, y):
             return torch.sum(1 / (torch.unsqueeze(x, -1) - y))
@@ -6520,6 +6534,7 @@ def fn(a):
         y = fn_compiled(x)
         self.assertTrue(y is not x)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_l1_loss(self):
         def fn(a, b):
             return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)
@@ -6920,6 +6935,7 @@ def fn(x):
             fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),)
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_any(self):
         def fn(x):
             return (
@@ -7686,6 +7702,8 @@ def fn(a, dim, index, b, reduce):
         )
 
     @skip_if_gpu_halide
+    # issue #1150
+    @skipIfRocmArch(NAVI_ARCH)
     def test_dense_mask_index(self):
         r"""
         There will be a little difference for reduce order between aten and inductor
@@ -8693,6 +8711,7 @@ def fn(a, b):
         b = torch.rand(2, 2, 1, 4, 1).int()
         self.common(fn, (a, b))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin1(self):
         def fn(x):
             return (aten.argmax(x), aten.argmin(x))
@@ -8704,6 +8723,7 @@ def fn(x):
             ],
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin2(self):
         def fn(x):
             return (
@@ -8715,6 +8735,7 @@ def fn(x):
 
         self.common(fn, (torch.randn([144, 144]),))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_duplicates(self):
         def fn(x):
             return (
@@ -8737,6 +8758,7 @@ def fn(x):
         self.common(fn, (t1,))
 
     @skip_if_halide  # nan behavior
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_nan(self):
         def fn(x):
             return (
@@ -8860,6 +8882,7 @@ def fn(x):
             ],
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_tmp_not_defined_issue1(self):
         def forward(
             primals_3,
@@ -9259,6 +9282,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         else:
             self.assertEqual(len(inps), 0)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_dtype_mismatch_issue(self):
         def fn(x):
             attn = torch.nn.functional.pad(x, [0, 1])
@@ -12349,6 +12373,7 @@ def test_rnn_compile_safe(self):
 
 class NanCheckerTest(TestCase):
     @config.patch("nan_asserts", True)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_nan_checker_pass(self):
         def f(x):
             return torch.softmax(x, dim=-1)

test/inductor/test_torchinductor_opinfo.py

Lines changed: 14 additions & 0 deletions
@@ -30,6 +30,7 @@
 from torch.testing._internal.common_methods_invocations import op_db, skipOps
 from torch.testing._internal.common_utils import (
     dtype_abbrs,
+    IS_NAVI,
     IS_MACOS,
     IS_X86,
     skipCUDAMemoryLeakCheckIf,
@@ -203,6 +204,19 @@ def format_op(op):
 # Tensors are not alike
 inductor_skips["cuda"]["logcumsumexp"] = {f32}
 inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
+if IS_NAVI:
+    inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("max", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("min", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["nn.functional.conv_transpose3d"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std_mean"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var_mean"] = {b8, f16, f32, f64, i32, i64}
 
 inductor_expected_failures_single_sample = defaultdict(dict)
 
torch/testing/_internal/common_utils.py

Lines changed: 20 additions & 0 deletions
@@ -1279,6 +1279,13 @@ def printErrors(self) -> None:
 IS_X86 = platform.machine() in ('x86_64', 'i386')
 IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
 
+IS_NAVI=False
+if torch.cuda.is_available():
+    prop = torch.cuda.get_device_properties(0)
+    gfx_arch = prop.gcnArchName.split(":")[0]
+    if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
+        IS_NAVI = True
+
 def is_avx512_vnni_supported():
     if sys.platform != 'linux':
         return False
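
For context (this note is not part of the commit's diff), the detection above keeps only the part of gcnArchName before the first colon because on some ROCm devices the reported arch string carries feature suffixes (for example a value like "gfx90a:sramecc+:xnack-"; suffixes vary by device and may be absent). A minimal, illustrative check, assuming a ROCm build of PyTorch:

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    # On ROCm builds gcnArchName reports the GPU ISA; anything after the first
    # colon is a feature flag, so only the leading token is compared.
    arch = getattr(props, "gcnArchName", "").split(":")[0]
    print(arch)  # e.g. "gfx1100" on a Navi 31 GPU
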
@@ -1754,6 +1761,19 @@ def wrapper(*args, **kwargs):
         return dec_fn(func)
     return dec_fn
 
+def skipIfRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def runOnRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
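
For reference, a minimal sketch (a hypothetical test file, not part of this commit) of how the new decorator is intended to be used, mirroring the NAVI_ARCH tuple defined in the inductor tests; it assumes a PyTorch build that includes this patch:

import torch
from torch.testing._internal.common_utils import run_tests, skipIfRocmArch, TestCase

NAVI_ARCH = ("gfx1100", "gfx1101")  # same tuple as in the inductor test files

class ExampleRocmTests(TestCase):
    @skipIfRocmArch(NAVI_ARCH)  # skipped only on ROCm when the GPU arch is in NAVI_ARCH
    def test_sum_matches_flattened_sum(self):
        x = torch.randn(8, 8)
        self.assertEqual(x.sum(), x.flatten().sum())

if __name__ == "__main__":
    run_tests()
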
