Commit 0b08ab6

jataylo authored and pruthvistony committed
Add skipIfRocmArch decorator for Navi skips (#1356)
1 parent acff1c1 commit 0b08ab6

File tree

4 files changed: +73 -2 lines changed
test/inductor/test_cuda_repro.py

Lines changed: 5 additions & 1 deletion

@@ -21,6 +21,7 @@
     freeze_rng_state,
     IS_FBCODE,
     skipIfRocm,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
 )

@@ -40,7 +41,7 @@
         sys.exit(0)
     raise

-
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm
 TestCase = test_torchinductor.TestCase
 ToTuple = test_torchinductor.ToTuple
 check_model_cuda = test_torchinductor.check_model_cuda

@@ -305,6 +306,7 @@ def foo(x):
         out_ref.add_(2)
         # self.assertEqual(out_ref, out)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_accuracy_issue1(self):
         class Repro(torch.nn.Module):
             def __init__(self):

@@ -341,6 +343,7 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
         assert same_two_models(mod, opt_mod, args), "Dynamo failed"

     @config.patch(allow_buffer_reuse=False)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103461(self):
         def forward(add_1):
             var_mean = torch.ops.aten.var_mean.correction(

@@ -828,6 +831,7 @@ def forward(self, x):
         res2 = jit_func(x)
         self.assertEqual(res1, res2)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103481(self):
         def fn(x, y):
             # NOTE: 6 dimensions is important! does not fail for 5 dimensions

test/inductor/test_torchinductor.py

Lines changed: 21 additions & 1 deletion

@@ -71,7 +71,7 @@
     IS_X86,
     parametrize,
     skipIfRocm,
-    subtest,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
 )

@@ -111,6 +111,7 @@
 _desired_test_bases = get_desired_device_type_test_bases()
 RUN_CPU = any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
 RUN_GPU = any(getattr(x, "device_type", "") == GPU_TYPE for x in _desired_test_bases)
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm

 aten = torch.ops.aten
 requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu")

@@ -1206,6 +1207,7 @@ def fn(x):
         # make sure things also work if they aren't unrolled
         self.common(fn, (torch.randn(8, 3),))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_sum_low_prec(self):
         # fp16 nyi for cpu
         if self.device == "cpu":

@@ -1216,6 +1218,7 @@ def fn(a):

         self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_prime_size(self):
         def fn(a):
             return torch.max(a), torch.sum(a)

@@ -1225,6 +1228,7 @@ def fn(a):
         sample[-1] = 1
         self.common(fn, (sample,))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_var(self):
         def fn(a):
             return torch.var(a)

@@ -2063,6 +2067,7 @@ def fn(a, b):

         self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_tensor_reduction(self):
         if not _has_sufficient_memory(self.device, 4.5 * 1024**3):  # 4.5 GiB
             raise unittest.SkipTest("insufficient memory")

@@ -2083,6 +2088,7 @@ def fn(a):
         expect = torch.tensor(2, dtype=torch.int8, device=self.device)
         self.assertEqual(actual, expect)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_broadcast_reduction(self):
         if self.device == "cpu":
             raise unittest.SkipTest("Fails on CPU")

@@ -3094,6 +3100,7 @@ def test_conv2d_channels_last(self):
             check_lowp=False,
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_conv2d_backward_channels_last(self):
         def fn(grad_output, inp, weight):
             convolution_backward_8 = torch.ops.aten.convolution_backward.default(

@@ -3839,6 +3846,7 @@ def fn(x, y):
         self.assertEqual(a.stride(), c.stride())
         self.assertEqual(c.stride()[2], 1)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_std(self):
         def fn(x):
             return (

@@ -3881,6 +3889,7 @@ def test_batch_norm_2d(self):

     # From yolov3
     @with_tf32_off
+    @skipIfRocmArch(NAVI_ARCH)
     def test_batch_norm_2d_2(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")

@@ -4016,6 +4025,7 @@ def fn(x):

         self.common(fn, (x,))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_cauchy(self):
         def fn(x, y):
             return torch.sum(1 / (torch.unsqueeze(x, -1) - y))

@@ -5277,6 +5287,7 @@ def fn(a):
         y = fn_compiled(x)
         self.assertTrue(y is not x)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_l1_loss(self):
         def fn(a, b):
             return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)

@@ -5673,6 +5684,7 @@ def fn(x):
             fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),)
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_any(self):
         def fn(x):
             return (

@@ -6361,6 +6373,7 @@ def fn(a, dim, index, b, reduce):
         )

     # issue #1150
+    @skipIfRocmArch(NAVI_ARCH)
     def test_dense_mask_index(self):
         if self.device == "cpu":
             raise unittest.SkipTest(

@@ -7216,6 +7229,7 @@ def fn(a, b):
         b = torch.rand(2, 2, 1, 4, 1).int()
         self.common(fn, (a, b))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin1(self):
         def fn(x):
             return (aten.argmax(x), aten.argmin(x))

@@ -7227,6 +7241,7 @@ def fn(x):
             ],
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin2(self):
         def fn(x):
             return (

@@ -7238,6 +7253,7 @@ def fn(x):

         self.common(fn, (torch.randn([144, 144]),))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_duplicates(self):
         def fn(x):
             return (

@@ -7259,6 +7275,7 @@ def fn(x):
         t1 = torch.randint(8, size=(1028, 1028))
         self.common(fn, (t1,))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_nan(self):
         def fn(x):
             return (

@@ -7391,6 +7408,7 @@ def fn(x):
             ],
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_tmp_not_defined_issue1(self):
         def forward(
             primals_3,

@@ -7786,6 +7804,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         else:
             self.assertEqual(len(inps), 0)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_dtype_mismatch_issue(self):
         def fn(x):
             attn = torch.nn.functional.pad(x, [0, 1])

@@ -9933,6 +9952,7 @@ def test_rnn_compile_safe(self):

 class NanCheckerTest(TestCase):
     @config.patch("nan_asserts", True)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_nan_checker_pass(self):
         def f(x):
             return torch.softmax(x, dim=-1)

test/inductor/test_torchinductor_opinfo.py

Lines changed: 14 additions & 0 deletions

@@ -31,6 +31,7 @@
 from torch.testing._internal.common_methods_invocations import op_db, skipOps
 from torch.testing._internal.common_utils import (
     dtype_abbrs,
+    IS_NAVI,
     IS_MACOS,
     IS_X86,
     skipCUDAMemoryLeakCheckIf,

@@ -201,6 +202,19 @@ def format_op(op):
 # Tensors are not alike
 inductor_skips["cuda"]["logcumsumexp"] = {f32}
 inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
+if IS_NAVI:
+    inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("max", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("min", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["nn.functional.conv_transpose3d"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std_mean"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var_mean"] = {b8, f16, f32, f64, i32, i64}

 inductor_expected_failures_single_sample = defaultdict(dict)

torch/testing/_internal/common_utils.py

Lines changed: 33 additions & 0 deletions

@@ -1171,6 +1171,13 @@ def printErrors(self) -> None:
 IS_X86 = platform.machine() in ('x86_64', 'i386')
 IS_ARM64 = platform.machine() in ('arm64', 'aarch64')

+IS_NAVI=False
+if torch.cuda.is_available():
+    prop = torch.cuda.get_device_properties(0)
+    gfx_arch = prop.gcnArchName.split(":")[0]
+    if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
+        IS_NAVI = True
+
 def is_avx512_vnni_supported():
     if sys.platform != 'linux':
         return False
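Roughly how the detection above behaves. The gcnArchName string below is an assumed example of the "arch:feature-flags" form ROCm devices typically report; the real value comes from the device properties.

# Assumed example value; on a ROCm build the real string comes from
# torch.cuda.get_device_properties(0).gcnArchName.
gcn_arch_name = "gfx1100:sramecc+:xnack-"
gfx_arch = gcn_arch_name.split(":")[0]
print(gfx_arch)                                       # gfx1100
print(gfx_arch in ["gfx1100", "gfx1101", "gfx1102"])  # True, so IS_NAVI would be set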

@@ -1545,6 +1552,19 @@ def wrapper(*args, **kwargs):
         return dec_fn(func)
     return dec_fn

+def skipIfRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def runOnRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):

@@ -1554,6 +1574,19 @@ def wrapper(*args, **kwargs):
         raise unittest.SkipTest("test currently only works on the ROCm stack")
     return wrapper

+def runOnRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] not in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
     def dec_fn(fn):
         reason = f"skipIfXpu: {msg}"
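
The two decorators added above are mirror images: skipIfRocmArch skips a test on ROCm when the device's base arch is in the given tuple, while runOnRocmArch skips it when the arch is not in the tuple; on non-ROCm builds both fall through and the test runs normally. A small sketch contrasting them; the MI300_ARCH tuple and the test bodies are illustrative assumptions, not part of this commit.

import unittest

import torch
from torch.testing._internal.common_utils import runOnRocmArch, skipIfRocmArch

NAVI_ARCH = ("gfx1100", "gfx1101")           # same tuple used in the inductor tests above
MI300_ARCH = ("gfx940", "gfx941", "gfx942")  # assumed example tuple for the inverse case

class ArchGatedTests(unittest.TestCase):
    @skipIfRocmArch(NAVI_ARCH)
    def test_runs_everywhere_except_navi(self):
        # On ROCm, skipped when the device's base arch is in NAVI_ARCH.
        self.assertEqual(torch.ones(2).sum().item(), 2.0)

    @runOnRocmArch(MI300_ARCH)
    def test_runs_only_on_listed_archs(self):
        # On ROCm, skipped unless the device's base arch is in MI300_ARCH;
        # on non-ROCm builds both decorators pass through (TEST_WITH_ROCM is False).
        self.assertEqual(torch.ones(2).sum().item(), 2.0)

if __name__ == "__main__":
    unittest.main()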
