diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index f2d29140987bcd..cc8c167f0a78dd 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -31,9 +31,9 @@ from torch.testing._internal.common_methods_invocations import op_db, skipOps from torch.testing._internal.common_utils import ( dtype_abbrs, - IS_NAVI, IS_MACOS, IS_X86, + is_navi_arch, skipCUDAMemoryLeakCheckIf, skipIfCrossRef, skipIfTorchDynamo, @@ -201,7 +201,7 @@ def format_op(op): if TEST_WITH_ROCM: # Tensors are not alike inductor_skips["cuda"]["logcumsumexp"] = {f32} - if IS_NAVI: + if is_navi_arch(): inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64} inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64} inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64} diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index c909dcc0665a76..7801882f60a531 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1176,12 +1176,13 @@ def printErrors(self) -> None: IS_X86 = platform.machine() in ('x86_64', 'i386') IS_ARM64 = platform.machine() in ('arm64', 'aarch64') -IS_NAVI=False -if torch.cuda.is_available(): - prop = torch.cuda.get_device_properties(0) - gfx_arch = prop.gcnArchName.split(":")[0] - if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]: - IS_NAVI = True +def is_navi_arch(): + if torch.cuda.is_available(): + prop = torch.cuda.get_device_properties(0) + gfx_arch = prop.gcnArchName.split(":")[0] + if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]: + return True + return False def is_avx512_vnni_supported(): if sys.platform != 'linux':