Skip to content

Commit dc609b4

Browse files
BLOrange-AMDpruthvistony
authored andcommitted
Converted NAVI check as a function (#1364)
* Moved NAVI check to the test file * Revised NAVI check as a function
1 parent 90ac508 commit dc609b4

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

test/inductor/test_torchinductor_opinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
from torch.testing._internal.common_methods_invocations import op_db, skipOps
3232
from torch.testing._internal.common_utils import (
3333
dtype_abbrs,
34-
IS_NAVI,
3534
IS_MACOS,
3635
IS_X86,
36+
is_navi_arch,
3737
skipCUDAMemoryLeakCheckIf,
3838
skipIfCrossRef,
3939
skipIfTorchDynamo,
@@ -204,7 +204,7 @@ def format_op(op):
204204
# Tensors are not alike
205205
inductor_skips["cuda"]["logcumsumexp"] = {f32}
206206
inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
207-
if IS_NAVI:
207+
if is_navi_arch():
208208
inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
209209
inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
210210
inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}

torch/testing/_internal/common_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,12 +1177,13 @@ def printErrors(self) -> None:
11771177
IS_X86 = platform.machine() in ('x86_64', 'i386')
11781178
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
11791179

1180-
IS_NAVI=False
1181-
if torch.cuda.is_available():
1182-
prop = torch.cuda.get_device_properties(0)
1183-
gfx_arch = prop.gcnArchName.split(":")[0]
1184-
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1185-
IS_NAVI = True
1180+
def is_navi_arch():
1181+
if torch.cuda.is_available():
1182+
prop = torch.cuda.get_device_properties(0)
1183+
gfx_arch = prop.gcnArchName.split(":")[0]
1184+
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1185+
return True
1186+
return False
11861187

11871188
def is_avx512_vnni_supported():
11881189
if sys.platform != 'linux':

0 commit comments

Comments
 (0)