Skip to content

Commit d0d25c3

Browse files
BLOrange-AMDpruthvistony
authored andcommitted
Converted NAVI check as a function (#1364)
* Moved NAVI check to the test file * Revised NAVI check as a function
1 parent 0b08ab6 commit d0d25c3

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

test/inductor/test_torchinductor_opinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
from torch.testing._internal.common_methods_invocations import op_db, skipOps
3232
from torch.testing._internal.common_utils import (
3333
dtype_abbrs,
34-
IS_NAVI,
3534
IS_MACOS,
3635
IS_X86,
36+
is_navi_arch,
3737
skipCUDAMemoryLeakCheckIf,
3838
skipIfCrossRef,
3939
skipIfTorchDynamo,
@@ -202,7 +202,7 @@ def format_op(op):
202202
# Tensors are not alike
203203
inductor_skips["cuda"]["logcumsumexp"] = {f32}
204204
inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
205-
if IS_NAVI:
205+
if is_navi_arch():
206206
inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
207207
inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
208208
inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}

torch/testing/_internal/common_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,12 +1171,13 @@ def printErrors(self) -> None:
11711171
IS_X86 = platform.machine() in ('x86_64', 'i386')
11721172
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
11731173

1174-
IS_NAVI=False
1175-
if torch.cuda.is_available():
1176-
prop = torch.cuda.get_device_properties(0)
1177-
gfx_arch = prop.gcnArchName.split(":")[0]
1178-
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1179-
IS_NAVI = True
1174+
def is_navi_arch():
1175+
if torch.cuda.is_available():
1176+
prop = torch.cuda.get_device_properties(0)
1177+
gfx_arch = prop.gcnArchName.split(":")[0]
1178+
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1179+
return True
1180+
return False
11801181

11811182
def is_avx512_vnni_supported():
11821183
if sys.platform != 'linux':

0 commit comments

Comments
 (0)