Skip to content

Commit 9939d0b

Browse files
committed
Enable torchvision build with ROCm on Windows (#6)
* changing flags according to clang * update changes * fix hipcc run on windows * changed include dirs conversion for rocm windows * fixed clang compilation with space in include paths * fix hip kernels build * restruct CXX and MSVC flags * change cxx flags * change include headers in hipified source files * change cuda to hip in include paths * cleaning code * fixed win_wrap_single_compile * fix paths in hipify * format code * removing unnecessary flags * cleaning code * fix cleaning * cleaning code * cleaning code * restructured ninja commands * renaming hipcc function * added set runtime lib function * add set hipcc runtime library function * format code
1 parent 0c8028e commit 9939d0b

File tree

2 files changed

+67
-26
lines changed

2 files changed

+67
-26
lines changed

torch/utils/cpp_extension.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,6 @@ def _join_rocm_home(*paths) -> str:
173173
if ROCM_HOME is None:
174174
raise OSError('ROCM_HOME environment variable is not set. '
175175
'Please set it to your ROCm install root.')
176-
elif IS_WINDOWS:
177-
raise OSError('Building PyTorch extensions using '
178-
'ROCm and Windows is not supported.')
179176
return os.path.join(ROCM_HOME, *paths)
180177

181178
def _join_sycl_home(*paths) -> str:
@@ -270,12 +267,14 @@ def _join_sycl_home(*paths) -> str:
270267
]
271268

272269
COMMON_HIP_FLAGS = [
273-
'-fPIC',
274270
'-D__HIP_PLATFORM_AMD__=1',
275271
'-DUSE_ROCM=1',
276272
'-DHIPBLAS_V2',
277273
]
278274

275+
if not IS_WINDOWS:
276+
COMMON_HIP_FLAGS.append('-fPIC')
277+
279278
COMMON_HIPCC_FLAGS = [
280279
'-DCUDA_HAS_FP16=1',
281280
'-D__HIP_NO_HALF_OPERATORS__=1',
@@ -513,6 +512,19 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N
513512
f'Please make sure to use an adequate version of {compiler_name} ({version_bound_str}).'
514513
)
515514

515+
# Specify Visual Studio C runtime library for hipcc
516+
def _set_hipcc_runtime_lib(is_standalone):
517+
debug = os.getenv("DEBUG")
518+
if is_standalone:
519+
if debug:
520+
COMMON_HIP_FLAGS.append('-fms-runtime-lib=static_dbg')
521+
else:
522+
COMMON_HIP_FLAGS.append('-fms-runtime-lib=static')
523+
else:
524+
if debug:
525+
COMMON_HIP_FLAGS.append('-fms-runtime-lib=dll_dbg')
526+
else:
527+
COMMON_HIP_FLAGS.append('-fms-runtime-lib=dll')
516528

517529
def _append_sycl_std_if_no_std_present(cflags):
518530
if not any(flag.startswith('-sycl-std=') for flag in cflags):
@@ -833,6 +845,9 @@ def unix_wrap_ninja_compile(sources,
833845
def win_cuda_flags(cflags):
    """Return the nvcc flag list for a Windows build.

    The result is the common nvcc flags, followed by the caller's ``cflags``,
    followed by the CUDA architecture flags derived from ``cflags``.
    """
    base_and_user_flags = COMMON_NVCC_FLAGS + cflags
    return base_and_user_flags + _get_cuda_arch_flags(cflags)
848+
849+
def win_hip_flags(cflags):
    """Return the hipcc flag list for a Windows build.

    The result is the common hipcc flags, the common HIP flags, the caller's
    ``cflags``, and finally the ROCm architecture flags derived from
    ``cflags``, in that order.
    """
    common_flags = COMMON_HIPCC_FLAGS + COMMON_HIP_FLAGS
    common_and_user_flags = common_flags + cflags
    return common_and_user_flags + _get_rocm_arch_flags(cflags)
836851

837852
def win_wrap_single_compile(sources,
838853
output_dir=None,
@@ -870,19 +885,25 @@ def spawn(cmd):
870885
src = src_list[0]
871886
obj = obj_list[0]
872887
if _is_cuda_file(src):
873-
nvcc = _join_cuda_home('bin', 'nvcc')
888+
if IS_HIP_EXTENSION:
889+
nvcc = _get_hipcc_path()
890+
else:
891+
nvcc = _join_cuda_home('bin', 'nvcc')
874892
if isinstance(self.cflags, dict):
875893
cflags = self.cflags['nvcc']
876894
elif isinstance(self.cflags, list):
877895
cflags = self.cflags
878896
else:
879897
cflags = []
880898

881-
cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env']
899+
if IS_HIP_EXTENSION:
900+
cflags = win_hip_flags(cflags)
901+
else:
902+
cflags = win_cuda_flags(cflags) + ['-std=c++17', '--use-local-env']
903+
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
904+
cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
882905
for flag in COMMON_MSVC_FLAGS:
883906
cflags = ['-Xcompiler', flag] + cflags
884-
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
885-
cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
886907
cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags
887908
elif isinstance(self.cflags, dict):
888909
cflags = COMMON_MSVC_FLAGS + self.cflags['cxx']
@@ -911,7 +932,6 @@ def win_wrap_ninja_compile(sources,
911932
extra_preargs=None,
912933
extra_postargs=None,
913934
depends=None):
914-
915935
if not self.compiler.initialized:
916936
self.compiler.initialize()
917937
output_dir = os.path.abspath(output_dir)
@@ -928,14 +948,20 @@ def win_wrap_ninja_compile(sources,
928948
self.compiler._setup_compile(output_dir, macros,
929949
include_dirs, sources,
930950
depends, extra_postargs)
951+
# Replace space with \ when using hipcc (hipcc passes includes to clang without ""s so clang sees space in include paths as new argument)
952+
if IS_HIP_EXTENSION:
953+
pp_opts = ["-I{}".format(s[2:].replace(" ", "\\")) if s.startswith('-I') else s for s in pp_opts]
931954
common_cflags = extra_preargs or []
932955
cflags = []
933956
if debug:
934957
cflags.extend(self.compiler.compile_options_debug)
935958
else:
936959
cflags.extend(self.compiler.compile_options)
937-
common_cflags.extend(COMMON_MSVC_FLAGS)
938-
cflags = cflags + common_cflags + pp_opts
960+
cflags = cflags + common_cflags + pp_opts + COMMON_MSVC_FLAGS
961+
if IS_HIP_EXTENSION:
962+
common_cflags.extend(COMMON_HIP_FLAGS)
963+
else:
964+
common_cflags.extend(COMMON_MSVC_FLAGS)
939965
with_cuda = any(map(_is_cuda_file, sources))
940966

941967
# extra_postargs can be either:
@@ -945,25 +971,31 @@ def win_wrap_ninja_compile(sources,
945971
post_cflags = extra_postargs['cxx']
946972
else:
947973
post_cflags = list(extra_postargs)
974+
if IS_HIP_EXTENSION:
975+
post_cflags = COMMON_HIP_FLAGS + post_cflags
948976
append_std17_if_no_std_present(post_cflags)
949977

950978
cuda_post_cflags = None
951979
cuda_cflags = None
952980
if with_cuda:
953-
cuda_cflags = ['-std=c++17', '--use-local-env']
981+
cuda_cflags = ['-std=c++17']
954982
for common_cflag in common_cflags:
955983
cuda_cflags.append('-Xcompiler')
956984
cuda_cflags.append(common_cflag)
957-
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
958-
cuda_cflags.append('-Xcudafe')
959-
cuda_cflags.append('--diag_suppress=' + ignore_warning)
985+
if not IS_HIP_EXTENSION:
986+
cuda_cflags.append('--use-local-env')
987+
for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
988+
cuda_cflags.append('-Xcudafe')
989+
cuda_cflags.append('--diag_suppress=' + ignore_warning)
960990
cuda_cflags.extend(pp_opts)
961991
if isinstance(extra_postargs, dict):
962992
cuda_post_cflags = extra_postargs['nvcc']
963993
else:
964994
cuda_post_cflags = list(extra_postargs)
965-
cuda_post_cflags = win_cuda_flags(cuda_post_cflags)
966-
995+
if IS_HIP_EXTENSION:
996+
cuda_post_cflags = win_hip_flags(cuda_post_cflags)
997+
else:
998+
cuda_post_cflags = win_cuda_flags(cuda_post_cflags)
967999
cflags = _nt_quote_args(cflags)
9681000
post_cflags = _nt_quote_args(post_cflags)
9691001
if with_cuda:
@@ -992,7 +1024,6 @@ def win_wrap_ninja_compile(sources,
9921024

9931025
# Return *all* object filenames, not just the ones we just built.
9941026
return objects
995-
9961027
# Monkey-patch the _compile or compile method.
9971028
# https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
9981029
if self.compiler.compiler_type == 'msvc':
@@ -2070,7 +2101,12 @@ def _jit_compile(name,
20702101

20712102
return _import_module_from_library(name, build_directory, is_python_module)
20722103

2073-
2104+
def _get_hipcc_path():
    """Return the path of the hipcc driver inside the ROCm ``bin`` directory.

    On Windows the ``hipcc.bat`` wrapper is used; on every other platform the
    plain ``hipcc`` executable is used. The path is built with
    ``_join_rocm_home``.
    """
    hipcc_name = 'hipcc.bat' if IS_WINDOWS else 'hipcc'
    return _join_rocm_home('bin', hipcc_name)
2109+
20742110
def _write_ninja_file_and_compile_objects(
20752111
sources: list[str],
20762112
objects,
@@ -2473,6 +2509,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
24732509
stdout_fileno = 1
24742510
subprocess.run(
24752511
command,
2512+
shell=IS_WINDOWS and IS_HIP_EXTENSION,
24762513
stdout=stdout_fileno if verbose else subprocess.PIPE,
24772514
stderr=subprocess.STDOUT,
24782515
cwd=build_directory,
@@ -2567,7 +2604,8 @@ def _write_ninja_file_to_build_library(path,
25672604
common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
25682605

25692606
if IS_WINDOWS:
2570-
cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17'] + extra_cflags
2607+
cflags = common_cflags + ['/std:c++17'] + extra_cflags
2608+
cflags += COMMON_HIP_FLAGS if IS_HIP_EXTENSION else COMMON_MSVC_FLAGS
25712609
cflags = _nt_quote_args(cflags)
25722610
else:
25732611
cflags = common_cflags + ['-fPIC', '-std=c++17'] + extra_cflags
@@ -2626,6 +2664,7 @@ def object_file_path(source_file: str) -> str:
26262664

26272665
objects = [object_file_path(src) for src in sources]
26282666
ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags
2667+
_set_hipcc_runtime_lib(is_standalone)
26292668

26302669
# The darwin linker needs explicit consent to ignore unresolved symbols.
26312670
if IS_MACOS:
@@ -2718,7 +2757,7 @@ def sanitize_flags(flags):
27182757
nvcc = os.getenv("PYTORCH_NVCC") # user can set nvcc compiler with ccache using the environment variable here
27192758
else:
27202759
if IS_HIP_EXTENSION:
2721-
nvcc = _join_rocm_home('bin', 'hipcc')
2760+
nvcc = _get_hipcc_path()
27222761
else:
27232762
nvcc = _join_cuda_home('bin', 'nvcc')
27242763
config.append(f'nvcc = {nvcc}')
@@ -2747,9 +2786,11 @@ def sanitize_flags(flags):
27472786
# See https://ninja-build.org/build.ninja.html for reference.
27482787
compile_rule = ['rule compile']
27492788
if IS_WINDOWS:
2789+
compiler_name = "$cxx" if IS_HIP_EXTENSION else "cl"
27502790
compile_rule.append(
2751-
' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
2752-
compile_rule.append(' deps = msvc')
2791+
f' command = {compiler_name} /showIncludes $cflags -c $in /Fo$out $post_cflags')
2792+
if not IS_HIP_EXTENSION:
2793+
compile_rule.append(' deps = msvc')
27532794
else:
27542795
compile_rule.append(
27552796
' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')

torch/utils/hipify/hipify_python.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def __exit__(self, type, value, traceback):
139139
for d in self.dirs_to_clean[::-1]:
140140
os.rmdir(d)
141141

142-
143142
# Follow UNIX convention for paths to use '/' instead of '\\' on Windows
144143
def _to_unix_path(path: str) -> str:
145144
return path.replace(os.sep, '/')
@@ -830,6 +829,7 @@ def preprocessor(
830829
show_progress: bool) -> HipifyResult:
831830
""" Executes the CUDA -> HIP conversion on the specified file. """
832831
fin_path = os.path.abspath(os.path.join(output_directory, filepath))
832+
filepath = _to_unix_path(filepath)
833833
hipify_result = HIPIFY_FINAL_RESULT[fin_path]
834834
if filepath not in all_files:
835835
hipify_result.hipified_path = None
@@ -932,8 +932,8 @@ def repl(m):
932932
return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None
933933
else header_filepath, header_dir))
934934
hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path
935-
return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
936-
else header_filepath, header_dir))
935+
return templ.format(_to_unix_path(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
936+
else header_filepath, header_dir)))
937937

938938
return m.group(0)
939939
return repl

0 commit comments

Comments
 (0)