Skip to content

Commit 5b77292

Browse files
BLOrange-AMDdnikolaev-amd
authored andcommitted
Converted NAVI check as a function (#1364)
* Moved NAVI check to the test file * Revised NAVI check as a function
1 parent 8a4d1e2 commit 5b77292

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

test/inductor/test_torchinductor_opinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
from torch.testing._internal.common_methods_invocations import op_db, skipOps
3131
from torch.testing._internal.common_utils import (
3232
dtype_abbrs,
33-
IS_NAVI,
3433
IS_MACOS,
3534
IS_X86,
35+
is_navi_arch,
3636
skipCUDAMemoryLeakCheckIf,
3737
skipIfCrossRef,
3838
skipIfTorchDynamo,
@@ -204,7 +204,7 @@ def format_op(op):
204204
# Tensors are not alike
205205
inductor_skips["cuda"]["logcumsumexp"] = {f32}
206206
inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
207-
if IS_NAVI:
207+
if is_navi_arch():
208208
inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
209209
inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
210210
inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}

torch/testing/_internal/common_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,12 +1279,13 @@ def printErrors(self) -> None:
12791279
IS_X86 = platform.machine() in ('x86_64', 'i386')
12801280
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
12811281

1282-
IS_NAVI=False
1283-
if torch.cuda.is_available():
1284-
prop = torch.cuda.get_device_properties(0)
1285-
gfx_arch = prop.gcnArchName.split(":")[0]
1286-
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1287-
IS_NAVI = True
1282+
def is_navi_arch():
1283+
if torch.cuda.is_available():
1284+
prop = torch.cuda.get_device_properties(0)
1285+
gfx_arch = prop.gcnArchName.split(":")[0]
1286+
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1287+
return True
1288+
return False
12881289

12891290
def is_avx512_vnni_supported():
12901291
if sys.platform != 'linux':

0 commit comments

Comments
 (0)