Skip to content

Commit a41829d

Browse files
Fix JIT path for Pytorch extensions and other hipify fixes (rocm5.0_internal_testing) (#906)
* Fix JIT path for Pytorch extensions and other hipify fixes * Typo during cherry-pick conflict resolution * is_cusparse_file should take a rel_path
1 parent 6cec426 commit a41829d

File tree

3 files changed

+130
-109
lines changed

3 files changed

+130
-109
lines changed

tools/amd_build/build_amd.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@
8989
"tools/autograd/templates/python_variable_methods.cpp",
9090
]
9191

92+
includes = [os.path.join(proj_dir, include) for include in includes]
93+
9294
for new_dir in args.extra_include_dir:
9395
abs_new_dir = os.path.join(proj_dir, new_dir)
9496
if os.path.exists(abs_new_dir):
@@ -112,6 +114,8 @@
112114
"torch/include/*",
113115
]
114116

117+
ignores = [os.path.join(proj_dir, ignore) for ignore in ignores]
118+
115119
# Check if the compiler is hip-clang.
116120
def is_hip_clang() -> bool:
117121
try:

torch/utils/cpp_extension.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .file_baton import FileBaton
1717
from ._cpp_extension_versioner import ExtensionVersioner
1818
from .hipify import hipify_python
19-
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
19+
from .hipify.hipify_python import GeneratedFileCleaner
2020
from typing import List, Optional, Union
2121

2222
from setuptools.command.build_ext import build_ext
@@ -945,16 +945,19 @@ def CUDAExtension(name, sources, *args, **kwargs):
945945
hipify_result = hipify_python.hipify(
946946
project_directory=build_dir,
947947
output_directory=build_dir,
948-
includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs] if include_dirs else ['*'],
948+
header_include_dirs=include_dirs,
949+
includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only
949950
extra_files=[os.path.abspath(s) for s in sources],
950951
show_detailed=True,
951952
is_pytorch_extension=True,
953+
hipify_extra_files_only=True, # don't hipify everything in includes path
952954
)
953955

954956
hipified_sources = set()
955957
for source in sources:
956958
s_abs = os.path.abspath(source)
957-
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
959+
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if (s_abs in hipify_result and
960+
hipify_result[s_abs]["hipified_path"] is not None) else s_abs)
958961

959962
sources = list(hipified_sources)
960963

@@ -1331,15 +1334,25 @@ def _jit_compile(name,
13311334
try:
13321335
with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
13331336
if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
1334-
hipify_python.hipify(
1337+
hipify_result = hipify_python.hipify(
13351338
project_directory=build_directory,
13361339
output_directory=build_directory,
1337-
includes=os.path.join(build_directory, '*'),
1340+
header_include_dirs=extra_include_paths,
13381341
extra_files=[os.path.abspath(s) for s in sources],
1342+
ignores=[os.path.join(ROCM_HOME, '*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers
13391343
show_detailed=verbose,
1344+
show_progress=verbose,
13401345
is_pytorch_extension=True,
13411346
clean_ctx=clean_ctx
13421347
)
1348+
1349+
hipified_sources = set()
1350+
for source in sources:
1351+
s_abs = os.path.abspath(source)
1352+
hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
1353+
1354+
sources = list(hipified_sources)
1355+
13431356
_write_ninja_file_and_build_library(
13441357
name=name,
13451358
sources=sources,
@@ -1835,10 +1848,6 @@ def _write_ninja_file_to_build_library(path,
18351848
cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
18361849
cuda_flags += extra_cuda_cflags
18371850
cuda_flags += _get_rocm_arch_flags(cuda_flags)
1838-
sources = [s if not _is_cuda_file(s) else
1839-
os.path.abspath(os.path.join(
1840-
path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True)))
1841-
for s in sources]
18421851
elif with_cuda:
18431852
cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
18441853
if IS_WINDOWS:
@@ -1949,6 +1958,8 @@ def sanitize_flags(flags):
19491958
nvcc = _join_cuda_home('bin', 'nvcc')
19501959
config.append(f'nvcc = {nvcc}')
19511960

1961+
if IS_HIP_EXTENSION:
1962+
post_cflags = COMMON_HIP_FLAGS + post_cflags
19521963
flags = [f'cflags = {" ".join(cflags)}']
19531964
flags.append(f'post_cflags = {" ".join(post_cflags)}')
19541965
if with_cuda:

0 commit comments

Comments
 (0)