|
16 | 16 | from .file_baton import FileBaton
|
17 | 17 | from ._cpp_extension_versioner import ExtensionVersioner
|
18 | 18 | from .hipify import hipify_python
|
19 |
| -from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner |
| 19 | +from .hipify.hipify_python import GeneratedFileCleaner |
20 | 20 | from typing import List, Optional, Union
|
21 | 21 |
|
22 | 22 | from setuptools.command.build_ext import build_ext
|
@@ -945,16 +945,19 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
945 | 945 | hipify_result = hipify_python.hipify(
|
946 | 946 | project_directory=build_dir,
|
947 | 947 | output_directory=build_dir,
|
948 |
| - includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs] if include_dirs else ['*'], |
| 948 | + header_include_dirs=include_dirs, |
| 949 | + includes=[os.path.join(build_dir, '*')], # limit scope to build_dir only |
949 | 950 | extra_files=[os.path.abspath(s) for s in sources],
|
950 | 951 | show_detailed=True,
|
951 | 952 | is_pytorch_extension=True,
|
| 953 | + hipify_extra_files_only=True, # don't hipify everything in includes path |
952 | 954 | )
|
953 | 955 |
|
954 | 956 | hipified_sources = set()
|
955 | 957 | for source in sources:
|
956 | 958 | s_abs = os.path.abspath(source)
|
957 |
| - hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs) |
| 959 | + hipified_sources.add(hipify_result[s_abs]["hipified_path"] if (s_abs in hipify_result and |
| 960 | + hipify_result[s_abs]["hipified_path"] is not None) else s_abs) |
958 | 961 |
|
959 | 962 | sources = list(hipified_sources)
|
960 | 963 |
|
@@ -1331,15 +1334,25 @@ def _jit_compile(name,
|
1331 | 1334 | try:
|
1332 | 1335 | with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
|
1333 | 1336 | if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
|
1334 |
| - hipify_python.hipify( |
| 1337 | + hipify_result = hipify_python.hipify( |
1335 | 1338 | project_directory=build_directory,
|
1336 | 1339 | output_directory=build_directory,
|
1337 |
| - includes=os.path.join(build_directory, '*'), |
| 1340 | + header_include_dirs=extra_include_paths, |
1338 | 1341 | extra_files=[os.path.abspath(s) for s in sources],
|
| 1342 | + ignores=[os.path.join(ROCM_HOME, '*'), os.path.join(_TORCH_PATH, '*')], # no need to hipify ROCm or PyTorch headers |
1339 | 1343 | show_detailed=verbose,
|
| 1344 | + show_progress=verbose, |
1340 | 1345 | is_pytorch_extension=True,
|
1341 | 1346 | clean_ctx=clean_ctx
|
1342 | 1347 | )
|
| 1348 | + |
| 1349 | + hipified_sources = set() |
| 1350 | + for source in sources: |
| 1351 | + s_abs = os.path.abspath(source) |
| 1352 | + hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs) |
| 1353 | + |
| 1354 | + sources = list(hipified_sources) |
| 1355 | + |
1343 | 1356 | _write_ninja_file_and_build_library(
|
1344 | 1357 | name=name,
|
1345 | 1358 | sources=sources,
|
@@ -1835,10 +1848,6 @@ def _write_ninja_file_to_build_library(path,
|
1835 | 1848 | cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
|
1836 | 1849 | cuda_flags += extra_cuda_cflags
|
1837 | 1850 | cuda_flags += _get_rocm_arch_flags(cuda_flags)
|
1838 |
| - sources = [s if not _is_cuda_file(s) else |
1839 |
| - os.path.abspath(os.path.join( |
1840 |
| - path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True))) |
1841 |
| - for s in sources] |
1842 | 1851 | elif with_cuda:
|
1843 | 1852 | cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
|
1844 | 1853 | if IS_WINDOWS:
|
@@ -1949,6 +1958,8 @@ def sanitize_flags(flags):
|
1949 | 1958 | nvcc = _join_cuda_home('bin', 'nvcc')
|
1950 | 1959 | config.append(f'nvcc = {nvcc}')
|
1951 | 1960 |
|
| 1961 | + if IS_HIP_EXTENSION: |
| 1962 | + post_cflags = COMMON_HIP_FLAGS + post_cflags |
1952 | 1963 | flags = [f'cflags = {" ".join(cflags)}']
|
1953 | 1964 | flags.append(f'post_cflags = {" ".join(post_cflags)}')
|
1954 | 1965 | if with_cuda:
|
|
0 commit comments