Skip to content

Commit 67c8c9c

Browse files
tvukovic-amd and pytorchmergebot
authored and committed
Enable torchvision build with ROCm on Windows (#6)
* changing flags according to clang * update changes * fix hipcc run on windows * changed include dirs conversion for rocm windows * fixed clang compilation with space in include paths * fix hip kernels build * restruct CXX and MSVC flags * change cxx flags * change include headers in hipified source files * change cuda to hip in include paths * cleaning code * fixed win_wrap_single_compile * fix paths in hipify * format code * removing unnecessary flags * cleaning code * fix cleaning * cleaning code * cleaning code * restructured ninja commands * renaming hipcc function * added set runtime lib function * add set hipcc runtime library function * format code
1 parent 790f93d commit 67c8c9c

File tree

2 files changed

+67
-26
lines changed

2 files changed

+67
-26
lines changed

torch/utils/cpp_extension.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,6 @@ def _join_rocm_home(*paths) -> str:
173173
if ROCM_HOME is None:
174174
raise OSError('ROCM_HOME environment variable is not set. '
175175
'Please set it to your ROCm install root.')
176-
elif IS_WINDOWS:
177-
raise OSError('Building PyTorch extensions using '
178-
'ROCm and Windows is not supported.')
179176
return os.path.join(ROCM_HOME, *paths)
180177

181178
def _join_sycl_home(*paths) -> str:
@@ -270,12 +267,14 @@ def _join_sycl_home(*paths) -> str:
270267
]
271268

272269
COMMON_HIP_FLAGS = [
273-
'-fPIC',
274270
'-D__HIP_PLATFORM_AMD__=1',
275271
'-DUSE_ROCM=1',
276272
'-DHIPBLAS_V2',
277273
]
278274

275+
if not IS_WINDOWS:
276+
COMMON_HIP_FLAGS.append('-fPIC')
277+
279278
COMMON_HIPCC_FLAGS = [
280279
'-DCUDA_HAS_FP16=1',
281280
'-D__HIP_NO_HALF_OPERATORS__=1',
@@ -511,6 +510,19 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N
511510
f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
512511
)
513512

513+
# Specify Visual Studio C runtime library for hipcc
514+
def _set_hipcc_runtime_lib(is_standalone):
    """Record which Visual Studio C runtime library hipcc should use.

    After this call the module-level ``COMMON_HIP_FLAGS`` list contains
    exactly one ``-fms-runtime-lib=<kind>`` entry (on Windows), where
    ``<kind>`` is:

    * ``static`` / ``static_dbg`` when ``is_standalone`` is truthy,
    * ``dll`` / ``dll_dbg`` otherwise.

    The ``_dbg`` variant is picked when the ``DEBUG`` environment variable is
    set to a non-empty string.

    Args:
        is_standalone: the caller's standalone-build switch; truthy selects
            the static runtime, falsy selects the DLL runtime.

    Returns:
        None. ``COMMON_HIP_FLAGS`` is updated in place.
    """
    # ``-fms-runtime-lib`` selects an MSVC runtime; it has no meaning for
    # non-Windows targets, so leave the shared flag list untouched there.
    if not IS_WINDOWS:
        return

    debug = os.getenv("DEBUG")
    linkage = 'static' if is_standalone else 'dll'
    suffix = '_dbg' if debug else ''
    runtime_flag = f'-fms-runtime-lib={linkage}{suffix}'

    # COMMON_HIP_FLAGS is shared module state and this function may run once
    # per build. Drop any selection left by a previous call so the list does
    # not accumulate duplicate (or contradictory static/dll) entries.
    # Slice assignment keeps the same list object that other code references.
    COMMON_HIP_FLAGS[:] = [
        flag for flag in COMMON_HIP_FLAGS
        if not flag.startswith('-fms-runtime-lib=')
    ]
    COMMON_HIP_FLAGS.append(runtime_flag)
514526

515527
def _append_sycl_std_if_no_std_present(cflags):
516528
if not any(flag.startswith('-sycl-std=') for flag in cflags):
@@ -831,6 +843,9 @@ def unix_wrap_ninja_compile(sources,
831843
def win_cuda_flags(cflags):
    """Build the nvcc flag list: common nvcc flags, then ``cflags``, then arch flags."""
    base_flags = COMMON_NVCC_FLAGS + cflags
    # Arch flags are derived from the caller-supplied flags and go last.
    return base_flags + _get_cuda_arch_flags(cflags)
846+
847+
def win_hip_flags(cflags):
    """Build the hipcc flag list: common hipcc flags, common HIP flags, ``cflags``, then ROCm arch flags."""
    hip_flags = COMMON_HIPCC_FLAGS + COMMON_HIP_FLAGS
    hip_flags = hip_flags + cflags
    # Arch flags are derived from the caller-supplied flags and go last.
    return hip_flags + _get_rocm_arch_flags(cflags)
834849

835850
def win_wrap_single_compile(sources,
836851
output_dir=None,
@@ -868,19 +883,25 @@ def spawn(cmd):
868883
src = src_list[0]
869884
obj = obj_list[0]
870885
if _is_cuda_file(src):
871-
nvcc = _join_cuda_home('bin', 'nvcc')
886+
if IS_HIP_EXTENSION:
887+
nvcc = _get_hipcc_path()
888+
else:
889+
nvcc = _join_cuda_home('bin', 'nvcc')
872890
if isinstance(self.cflags, dict):
873891
cflags = self.cflags['nvcc']
874892
elif isinstance(self.cflags, list):
875893
cflags = self.cflags
876894
else:
877895
cflags = []
878896

879-
cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env']
897+
if IS_HIP_EXTENSION:
898+
cflags = win_hip_flags(cflags)
899+
else:
900+
cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env']
901+
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
902+
cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
880903
for flag in COMMON_MSVC_FLAGS:
881904
cflags = ['-Xcompiler', flag] + cflags
882-
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
883-
cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
884905
cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags
885906
elif isinstance(self.cflags, dict):
886907
cflags = COMMON_MSVC_FLAGS + self.cflags['cxx']
@@ -909,7 +930,6 @@ def win_wrap_ninja_compile(sources,
909930
extra_preargs=None,
910931
extra_postargs=None,
911932
depends=None):
912-
913933
if not self.compiler.initialized:
914934
self.compiler.initialize()
915935
output_dir = os.path.abspath(output_dir)
@@ -926,14 +946,20 @@ def win_wrap_ninja_compile(sources,
926946
self.compiler._setup_compile(output_dir, macros,
927947
include_dirs, sources,
928948
depends, extra_postargs)
949+
# Replace space with \ when using hipcc (hipcc passes includes to clang without ""s so clang sees space in include paths as new argument)
950+
if IS_HIP_EXTENSION:
951+
pp_opts = ["-I{}".format(s[2:].replace(" ", "\\")) if s.startswith('-I') else s for s in pp_opts]
929952
common_cflags = extra_preargs or []
930953
cflags = []
931954
if debug:
932955
cflags.extend(self.compiler.compile_options_debug)
933956
else:
934957
cflags.extend(self.compiler.compile_options)
935-
common_cflags.extend(COMMON_MSVC_FLAGS)
936-
cflags = cflags + common_cflags + pp_opts
958+
cflags = cflags + common_cflags + pp_opts + COMMON_MSVC_FLAGS
959+
if IS_HIP_EXTENSION:
960+
common_cflags.extend(COMMON_HIP_FLAGS)
961+
else:
962+
common_cflags.extend(COMMON_MSVC_FLAGS)
937963
with_cuda = any(map(_is_cuda_file, sources))
938964

939965
# extra_postargs can be either:
@@ -943,25 +969,31 @@ def win_wrap_ninja_compile(sources,
943969
post_cflags = extra_postargs['cxx']
944970
else:
945971
post_cflags = list(extra_postargs)
972+
if IS_HIP_EXTENSION:
973+
post_cflags = COMMON_HIP_FLAGS + post_cflags
946974
append_std17_if_no_std_present(post_cflags)
947975

948976
cuda_post_cflags = None
949977
cuda_cflags = None
950978
if with_cuda:
951-
cuda_cflags = ['-std=c++17', '--use-local-env']
979+
cuda_cflags = ['-std=c++17']
952980
for common_cflag in common_cflags:
953981
cuda_cflags.append('-Xcompiler')
954982
cuda_cflags.append(common_cflag)
955-
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
956-
cuda_cflags.append('-Xcudafe')
957-
cuda_cflags.append('--diag_suppress=' + ignore_warning)
983+
if not IS_HIP_EXTENSION:
984+
cuda_cflags.append('--use-local-env')
985+
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
986+
cuda_cflags.append('-Xcudafe')
987+
cuda_cflags.append('--diag_suppress=' + ignore_warning)
958988
cuda_cflags.extend(pp_opts)
959989
if isinstance(extra_postargs, dict):
960990
cuda_post_cflags = extra_postargs['nvcc']
961991
else:
962992
cuda_post_cflags = list(extra_postargs)
963-
cuda_post_cflags = win_cuda_flags(cuda_post_cflags)
964-
993+
if IS_HIP_EXTENSION:
994+
cuda_post_cflags = win_hip_flags(cuda_post_cflags)
995+
else:
996+
cuda_post_cflags = win_cuda_flags(cuda_post_cflags)
965997
cflags = _nt_quote_args(cflags)
966998
post_cflags = _nt_quote_args(post_cflags)
967999
if with_cuda:
@@ -990,7 +1022,6 @@ def win_wrap_ninja_compile(sources,
9901022

9911023
# Return *all* object filenames, not just the ones we just built.
9921024
return objects
993-
9941025
# Monkey-patch the _compile or compile method.
9951026
# https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
9961027
if self.compiler.compiler_type == 'msvc':
@@ -2076,7 +2107,12 @@ def _jit_compile(name,
20762107

20772108
return _import_module_from_library(name, build_directory, is_python_module)
20782109

2079-
2110+
def _get_hipcc_path():
    """Return the path of the hipcc launcher inside ``ROCM_HOME/bin``.

    The launcher is named ``hipcc.bat`` on Windows and ``hipcc`` elsewhere.
    """
    launcher = 'hipcc.bat' if IS_WINDOWS else 'hipcc'
    return _join_rocm_home('bin', launcher)
2115+
20802116
def _write_ninja_file_and_compile_objects(
20812117
sources: list[str],
20822118
objects,
@@ -2479,6 +2515,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
24792515
stdout_fileno = 1
24802516
subprocess.run(
24812517
command,
2518+
shell=IS_WINDOWS and IS_HIP_EXTENSION,
24822519
stdout=stdout_fileno if verbose else subprocess.PIPE,
24832520
stderr=subprocess.STDOUT,
24842521
cwd=build_directory,
@@ -2573,7 +2610,8 @@ def _write_ninja_file_to_build_library(path,
25732610
common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
25742611

25752612
if IS_WINDOWS:
2576-
cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags
2613+
cflags = common_cflags + ['/std:c++17'] + extra_cflags
2614+
# Extend the flag list built on the previous statement: HIP flags for ROCm builds, MSVC flags otherwise.
cflags += COMMON_HIP_FLAGS if IS_HIP_EXTENSION else COMMON_MSVC_FLAGS
25772615
cflags = _nt_quote_args(cflags)
25782616
else:
25792617
cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags
@@ -2632,6 +2670,7 @@ def object_file_path(source_file: str) -> str:
26322670

26332671
objects = [object_file_path(src) for src in sources]
26342672
ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags
2673+
_set_hipcc_runtime_lib(is_standalone)
26352674

26362675
# The darwin linker needs explicit consent to ignore unresolved symbols.
26372676
if IS_MACOS:
@@ -2724,7 +2763,7 @@ def sanitize_flags(flags):
27242763
nvcc = os.getenv("PYTORCH_NVCC") # user can set nvcc compiler with ccache using the environment variable here
27252764
else:
27262765
if IS_HIP_EXTENSION:
2727-
nvcc = _join_rocm_home('bin', 'hipcc')
2766+
nvcc = _get_hipcc_path()
27282767
else:
27292768
nvcc = _join_cuda_home('bin', 'nvcc')
27302769
config.append(f'nvcc = {nvcc}')
@@ -2753,9 +2792,11 @@ def sanitize_flags(flags):
27532792
# See https://ninja-build.org/build.ninja.html for reference.
27542793
compile_rule = ['rule compile']
27552794
if IS_WINDOWS:
2795+
compiler_name = "$cxx" if IS_HIP_EXTENSION else "cl"
27562796
compile_rule.append(
2757-
' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
2758-
compile_rule.append(' deps = msvc')
2797+
f' command = {compiler_name} /showIncludes $cflags -c $in /Fo$out $post_cflags')
2798+
if not IS_HIP_EXTENSION:
2799+
compile_rule.append(' deps = msvc')
27592800
else:
27602801
compile_rule.append(
27612802
' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')

torch/utils/hipify/hipify_python.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def __exit__(self, type, value, traceback):
139139
for d in self.dirs_to_clean[::-1]:
140140
os.rmdir(d)
141141

142-
143142
# Follow UNIX convention for paths to use '/' instead of '\\' on Windows
144143
def _to_unix_path(path: str) -> str:
145144
return path.replace(os.sep, '/')
@@ -830,6 +829,7 @@ def preprocessor(
830829
show_progress: bool) -> HipifyResult:
831830
""" Executes the CUDA -> HIP conversion on the specified file. """
832831
fin_path = os.path.abspath(os.path.join(output_directory, filepath))
832+
filepath = _to_unix_path(filepath)
833833
hipify_result = HIPIFY_FINAL_RESULT[fin_path]
834834
if filepath not in all_files:
835835
hipify_result.hipified_path = None
@@ -932,8 +932,8 @@ def repl(m):
932932
return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None
933933
else header_filepath, header_dir))
934934
hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path
935-
return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
936-
else header_filepath, header_dir))
935+
return templ.format(_to_unix_path(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
936+
else header_filepath, header_dir)))
937937

938938
return m.group(0)
939939
return repl

0 commit comments

Comments
 (0)