tools/amd_build/build_amd.py (4 additions, 0 deletions)
@@ -89,6 +89,8 @@
     "tools/autograd/templates/python_variable_methods.cpp",
 ]
 
+includes = [os.path.join(proj_dir, include) for include in includes]
+
 for new_dir in args.extra_include_dir:
     abs_new_dir = os.path.join(proj_dir, new_dir)
     if os.path.exists(abs_new_dir):
@@ -112,6 +114,8 @@
     "torch/include/*",
 ]
 
+ignores = [os.path.join(proj_dir, ignore) for ignore in ignores]
+
 # Check if the compiler is hip-clang.
 def is_hip_clang() -> bool:
     try:
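For context, a minimal sketch (paths and list entries are illustrative, not taken from the PR) of what the two added lines do: the relative glob patterns in includes and ignores are anchored at proj_dir, so pattern matching no longer depends on the working directory from which build_amd.py is invoked.

```python
# Illustrative only: anchoring relative glob patterns at the project root,
# so matching works regardless of the current working directory.
import os

proj_dir = "/home/user/pytorch"          # hypothetical checkout location
includes = ["caffe2/operators/*", "aten/src/ATen/cuda/*"]

includes = [os.path.join(proj_dir, include) for include in includes]
print(includes[0])   # /home/user/pytorch/caffe2/operators/*
```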
torch/utils/cpp_extension.py (20 additions, 9 deletions)
@@ -16,7 +16,7 @@
 from .file_baton import FileBaton
 from ._cpp_extension_versioner import ExtensionVersioner
 from .hipify import hipify_python
-from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
+from .hipify.hipify_python import GeneratedFileCleaner
 from typing import List, Optional, Union
 
 from setuptools.command.build_ext import build_ext
@@ -939,16 +939,19 @@ def CUDAExtension(name, sources, *args, **kwargs):
         hipify_result = hipify_python.hipify(
             project_directory=build_dir,
             output_directory=build_dir,
-            includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs] if include_dirs else ['*'],
+            header_include_dirs=include_dirs,
+            includes=[os.path.join(build_dir, '*')],  # limit scope to build_dir only
             extra_files=[os.path.abspath(s) for s in sources],
             show_detailed=True,
             is_pytorch_extension=True,
+            hipify_extra_files_only=True,  # don't hipify everything in includes path
         )
 
         hipified_sources = set()
         for source in sources:
             s_abs = os.path.abspath(source)
-            hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
+            hipified_sources.add(hipify_result[s_abs]["hipified_path"] if (s_abs in hipify_result and
+                                 hipify_result[s_abs]["hipified_path"] is not None) else s_abs)
 
         sources = list(hipified_sources)
 
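For reviewers trying the ahead-of-time path: a minimal setup.py sketch (extension and file names are hypothetical) that exercises this hunk on ROCm, where include_dirs passed to CUDAExtension now feeds hipify's header_include_dirs and only the listed sources plus those headers are hipified.

```python
# Hypothetical setup.py; names and paths are illustrative only.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_rocm_ext",
    ext_modules=[
        CUDAExtension(
            name="my_rocm_ext",
            sources=["csrc/my_ext.cpp", "csrc/my_kernels.cu"],
            # On ROCm these directories are forwarded to hipify as
            # header_include_dirs by the hunk above.
            include_dirs=["csrc/include"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```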
@@ -1325,15 +1328,25 @@ def _jit_compile(name,
             try:
                 with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
                     if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
-                        hipify_python.hipify(
+                        hipify_result = hipify_python.hipify(
                             project_directory=build_directory,
                             output_directory=build_directory,
-                            includes=os.path.join(build_directory, '*'),
+                            header_include_dirs=extra_include_paths,
                             extra_files=[os.path.abspath(s) for s in sources],
+                            ignores=[os.path.join(ROCM_HOME, '*'), os.path.join(_TORCH_PATH, '*')],  # no need to hipify ROCm or PyTorch headers
                             show_detailed=verbose,
+                            show_progress=verbose,
                             is_pytorch_extension=True,
                             clean_ctx=clean_ctx
                         )
+
+                        hipified_sources = set()
+                        for source in sources:
+                            s_abs = os.path.abspath(source)
+                            hipified_sources.add(hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else s_abs)
+
+                        sources = list(hipified_sources)
+
                     _write_ninja_file_and_build_library(
                         name=name,
                         sources=sources,
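And the JIT path this hunk changes, as a sketch (module and file names are hypothetical): torch.utils.cpp_extension.load() forwards extra_include_paths to hipify as header_include_dirs, and the returned hipify_result is now used to swap each source for its hipified counterpart before the ninja file is written.

```python
# Hypothetical JIT usage; names and paths are illustrative only.
from torch.utils.cpp_extension import load

my_ext = load(
    name="my_rocm_ext_jit",
    sources=["csrc/my_ext.cpp", "csrc/my_kernels.cu"],
    # Forwarded to hipify as header_include_dirs by the hunk above.
    extra_include_paths=["csrc/include"],
    # verbose also drives hipify's show_detailed/show_progress here.
    verbose=True,
)
```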
@@ -1826,10 +1839,6 @@ def _write_ninja_file_to_build_library(path,
         cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
         cuda_flags += extra_cuda_cflags
         cuda_flags += _get_rocm_arch_flags(cuda_flags)
-        sources = [s if not _is_cuda_file(s) else
-                   os.path.abspath(os.path.join(
-                       path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True)))
-                   for s in sources]
     elif with_cuda:
         cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags()
         if IS_WINDOWS:
@@ -1940,6 +1949,8 @@ def sanitize_flags(flags):
         nvcc = _join_cuda_home('bin', 'nvcc')
         config.append(f'nvcc = {nvcc}')
 
+    if IS_HIP_EXTENSION:
+        post_cflags = COMMON_HIP_FLAGS + post_cflags
     flags = [f'cflags = {" ".join(cflags)}']
     flags.append(f'post_cflags = {" ".join(post_cflags)}')
     if with_cuda:
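A toy sketch of the last hunk's effect on the generated ninja variables (the COMMON_HIP_FLAGS values below are assumed for illustration, not copied from the PR):

```python
# Illustrative only; COMMON_HIP_FLAGS contents are assumed.
COMMON_HIP_FLAGS = ["-fPIC", "-D__HIP_PLATFORM_HCC__=1"]
post_cflags = ["-O3"]
IS_HIP_EXTENSION = True

if IS_HIP_EXTENSION:
    post_cflags = COMMON_HIP_FLAGS + post_cflags

# Mirrors how the flags end up in the ninja file's post_cflags line.
print(f'post_cflags = {" ".join(post_cflags)}')
# post_cflags = -fPIC -D__HIP_PLATFORM_HCC__=1 -O3
```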