@@ -173,9 +173,6 @@ def _join_rocm_home(*paths) -> str:
173
173
if ROCM_HOME is None :
174
174
raise OSError ('ROCM_HOME environment variable is not set. '
175
175
'Please set it to your ROCm install root.' )
176
- elif IS_WINDOWS :
177
- raise OSError ('Building PyTorch extensions using '
178
- 'ROCm and Windows is not supported.' )
179
176
return os .path .join (ROCM_HOME , * paths )
180
177
181
178
def _join_sycl_home (* paths ) -> str :
@@ -270,12 +267,14 @@ def _join_sycl_home(*paths) -> str:
270
267
]
271
268
272
269
# Flags passed to every HIP compilation. '-fPIC' (position-independent code)
# is requested last, and only when not building on Windows.
COMMON_HIP_FLAGS = [
    '-D__HIP_PLATFORM_AMD__=1',
    '-DUSE_ROCM=1',
    '-DHIPBLAS_V2',
    *(() if IS_WINDOWS else ('-fPIC',)),
]
277
+
279
278
COMMON_HIPCC_FLAGS = [
280
279
'-DCUDA_HAS_FP16=1' ,
281
280
'-D__HIP_NO_HALF_OPERATORS__=1' ,
@@ -511,6 +510,19 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N
511
510
f'Please make sure to use an adequate version of { compiler_name } ({ version_bound_str } ).'
512
511
)
513
512
513
def _set_hipcc_runtime_lib(is_standalone):
    """Specify the Visual Studio C runtime library for hipcc.

    Leaves exactly one ``-fms-runtime-lib=<kind>`` entry at the end of the
    module-level ``COMMON_HIP_FLAGS`` list:

    * ``static`` / ``static_dbg`` when ``is_standalone`` is truthy,
    * ``dll`` / ``dll_dbg`` otherwise,

    where the ``_dbg`` variant is chosen when the ``DEBUG`` environment
    variable is set to a non-empty value.

    Args:
        is_standalone: truthy to select the static runtime, falsy to select
            the DLL runtime.

    NOTE(review): this flag only makes sense for MSVC-style runtimes;
    presumably callers should invoke this for Windows HIP builds only --
    confirm against the call sites.
    """
    runtime_flag_prefix = '-fms-runtime-lib='
    linkage = 'static' if is_standalone else 'dll'
    debug_suffix = '_dbg' if os.getenv("DEBUG") else ''

    # Drop any selection made by an earlier call so repeated builds in the same
    # process do not pile up duplicate or conflicting runtime flags. The slice
    # assignment mutates the list in place, keeping every existing reference to
    # COMMON_HIP_FLAGS valid.
    COMMON_HIP_FLAGS[:] = [
        flag for flag in COMMON_HIP_FLAGS
        if not flag.startswith(runtime_flag_prefix)
    ]
    COMMON_HIP_FLAGS.append(runtime_flag_prefix + linkage + debug_suffix)
514
526
515
527
def _append_sycl_std_if_no_std_present (cflags ):
516
528
if not any (flag .startswith ('-sycl-std=' ) for flag in cflags ):
@@ -831,6 +843,9 @@ def unix_wrap_ninja_compile(sources,
831
843
def win_cuda_flags (cflags ):
832
844
return (COMMON_NVCC_FLAGS +
833
845
cflags + _get_cuda_arch_flags (cflags ))
846
+
847
def win_hip_flags(cflags):
    """Build the hipcc flag list for a Windows build.

    The result is the common hipcc flags, then the common HIP flags, then the
    caller-supplied ``cflags``, then the ROCm architecture flags derived from
    ``cflags``. A new list is returned; the inputs are not modified here.
    """
    leading_flags = COMMON_HIPCC_FLAGS + COMMON_HIP_FLAGS + cflags
    return leading_flags + _get_rocm_arch_flags(cflags)
834
849
835
850
def win_wrap_single_compile (sources ,
836
851
output_dir = None ,
@@ -868,19 +883,25 @@ def spawn(cmd):
868
883
src = src_list [0 ]
869
884
obj = obj_list [0 ]
870
885
if _is_cuda_file (src ):
871
- nvcc = _join_cuda_home ('bin' , 'nvcc' )
886
+ if IS_HIP_EXTENSION :
887
+ nvcc = _get_hipcc_path ()
888
+ else :
889
+ nvcc = _join_cuda_home ('bin' , 'nvcc' )
872
890
if isinstance (self .cflags , dict ):
873
891
cflags = self .cflags ['nvcc' ]
874
892
elif isinstance (self .cflags , list ):
875
893
cflags = self .cflags
876
894
else :
877
895
cflags = []
878
896
879
- cflags = win_cuda_flags (cflags ) + ['-std=c++17' , '--use-local-env' ]
897
+ if IS_HIP_EXTENSION :
898
+ cflags = win_hip_flags (cflags )
899
+ else :
900
+ cflags = win_cuda_flags (cflags ) + ['-std=c++17' , '--use-local-env' ]
901
+ for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
902
+ cflags = ['-Xcudafe' , '--diag_suppress=' + ignore_warning ] + cflags
880
903
for flag in COMMON_MSVC_FLAGS :
881
904
cflags = ['-Xcompiler' , flag ] + cflags
882
- for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
883
- cflags = ['-Xcudafe' , '--diag_suppress=' + ignore_warning ] + cflags
884
905
cmd = [nvcc , '-c' , src , '-o' , obj ] + include_list + cflags
885
906
elif isinstance (self .cflags , dict ):
886
907
cflags = COMMON_MSVC_FLAGS + self .cflags ['cxx' ]
@@ -909,7 +930,6 @@ def win_wrap_ninja_compile(sources,
909
930
extra_preargs = None ,
910
931
extra_postargs = None ,
911
932
depends = None ):
912
-
913
933
if not self .compiler .initialized :
914
934
self .compiler .initialize ()
915
935
output_dir = os .path .abspath (output_dir )
@@ -926,14 +946,20 @@ def win_wrap_ninja_compile(sources,
926
946
self .compiler ._setup_compile (output_dir , macros ,
927
947
include_dirs , sources ,
928
948
depends , extra_postargs )
949
+ # Escape spaces in include paths with a backslash when using hipcc: hipcc forwards include paths to clang unquoted, so clang would otherwise treat a space inside a path as an argument separator.
950
+ if IS_HIP_EXTENSION :
951
+ pp_opts = ["-I{}" .format (s [2 :].replace (" " , "\\ " )) if s .startswith ('-I' ) else s for s in pp_opts ]
929
952
common_cflags = extra_preargs or []
930
953
cflags = []
931
954
if debug :
932
955
cflags .extend (self .compiler .compile_options_debug )
933
956
else :
934
957
cflags .extend (self .compiler .compile_options )
935
- common_cflags .extend (COMMON_MSVC_FLAGS )
936
- cflags = cflags + common_cflags + pp_opts
958
+ cflags = cflags + common_cflags + pp_opts + COMMON_MSVC_FLAGS
959
+ if IS_HIP_EXTENSION :
960
+ common_cflags .extend (COMMON_HIP_FLAGS )
961
+ else :
962
+ common_cflags .extend (COMMON_MSVC_FLAGS )
937
963
with_cuda = any (map (_is_cuda_file , sources ))
938
964
939
965
# extra_postargs can be either:
@@ -943,25 +969,31 @@ def win_wrap_ninja_compile(sources,
943
969
post_cflags = extra_postargs ['cxx' ]
944
970
else :
945
971
post_cflags = list (extra_postargs )
972
+ if IS_HIP_EXTENSION :
973
+ post_cflags = COMMON_HIP_FLAGS + post_cflags
946
974
append_std17_if_no_std_present (post_cflags )
947
975
948
976
cuda_post_cflags = None
949
977
cuda_cflags = None
950
978
if with_cuda :
951
- cuda_cflags = ['-std=c++17' , '--use-local-env' ]
979
+ cuda_cflags = ['-std=c++17' ]
952
980
for common_cflag in common_cflags :
953
981
cuda_cflags .append ('-Xcompiler' )
954
982
cuda_cflags .append (common_cflag )
955
- for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
956
- cuda_cflags .append ('-Xcudafe' )
957
- cuda_cflags .append ('--diag_suppress=' + ignore_warning )
983
+ if not IS_HIP_EXTENSION :
984
+ cuda_cflags .append ('--use-local-env' )
985
+ for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
986
+ cuda_cflags .append ('-Xcudafe' )
987
+ cuda_cflags .append ('--diag_suppress=' + ignore_warning )
958
988
cuda_cflags .extend (pp_opts )
959
989
if isinstance (extra_postargs , dict ):
960
990
cuda_post_cflags = extra_postargs ['nvcc' ]
961
991
else :
962
992
cuda_post_cflags = list (extra_postargs )
963
- cuda_post_cflags = win_cuda_flags (cuda_post_cflags )
964
-
993
+ if IS_HIP_EXTENSION :
994
+ cuda_post_cflags = win_hip_flags (cuda_post_cflags )
995
+ else :
996
+ cuda_post_cflags = win_cuda_flags (cuda_post_cflags )
965
997
cflags = _nt_quote_args (cflags )
966
998
post_cflags = _nt_quote_args (post_cflags )
967
999
if with_cuda :
@@ -990,7 +1022,6 @@ def win_wrap_ninja_compile(sources,
990
1022
991
1023
# Return *all* object filenames, not just the ones we just built.
992
1024
return objects
993
-
994
1025
# Monkey-patch the _compile or compile method.
995
1026
# https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
996
1027
if self .compiler .compiler_type == 'msvc' :
@@ -2076,7 +2107,12 @@ def _jit_compile(name,
2076
2107
2077
2108
return _import_module_from_library (name , build_directory , is_python_module )
2078
2109
2079
-
2110
def _get_hipcc_path():
    """Return the hipcc launcher path under the ROCm ``bin`` directory.

    The launcher is ``hipcc.bat`` when ``IS_WINDOWS`` is true and ``hipcc``
    otherwise; the path is assembled by ``_join_rocm_home``.
    """
    launcher_name = 'hipcc.bat' if IS_WINDOWS else 'hipcc'
    return _join_rocm_home('bin', launcher_name)
2115
+
2080
2116
def _write_ninja_file_and_compile_objects (
2081
2117
sources : list [str ],
2082
2118
objects ,
@@ -2479,6 +2515,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
2479
2515
stdout_fileno = 1
2480
2516
subprocess .run (
2481
2517
command ,
2518
+ shell = IS_WINDOWS and IS_HIP_EXTENSION ,
2482
2519
stdout = stdout_fileno if verbose else subprocess .PIPE ,
2483
2520
stderr = subprocess .STDOUT ,
2484
2521
cwd = build_directory ,
@@ -2573,7 +2610,8 @@ def _write_ninja_file_to_build_library(path,
2573
2610
common_cflags += [f"{ x } " for x in _get_glibcxx_abi_build_flags ()]
2574
2611
2575
2612
if IS_WINDOWS :
2576
- cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17' ] + extra_cflags
2613
+ cflags = common_cflags + ['/std:c++17' ] + extra_cflags
2614
+ cflags += COMMON_HIP_FLAGS if IS_HIP_EXTENSION else COMMON_MSVC_FLAGS
2577
2615
cflags = _nt_quote_args (cflags )
2578
2616
else :
2579
2617
cflags = common_cflags + ['-fPIC' , '-std=c++17' ] + extra_cflags
@@ -2632,6 +2670,7 @@ def object_file_path(source_file: str) -> str:
2632
2670
2633
2671
objects = [object_file_path (src ) for src in sources ]
2634
2672
ldflags = ([] if is_standalone else [SHARED_FLAG ]) + extra_ldflags
2673
+ _set_hipcc_runtime_lib (is_standalone )
2635
2674
2636
2675
# The darwin linker needs explicit consent to ignore unresolved symbols.
2637
2676
if IS_MACOS :
@@ -2724,7 +2763,7 @@ def sanitize_flags(flags):
2724
2763
nvcc = os .getenv ("PYTORCH_NVCC" ) # user can set nvcc compiler with ccache using the environment variable here
2725
2764
else :
2726
2765
if IS_HIP_EXTENSION :
2727
- nvcc = _join_rocm_home ( 'bin' , 'hipcc' )
2766
+ nvcc = _get_hipcc_path ( )
2728
2767
else :
2729
2768
nvcc = _join_cuda_home ('bin' , 'nvcc' )
2730
2769
config .append (f'nvcc = { nvcc } ' )
@@ -2753,9 +2792,11 @@ def sanitize_flags(flags):
2753
2792
# See https://ninja-build.org/build.ninja.html for reference.
2754
2793
compile_rule = ['rule compile' ]
2755
2794
if IS_WINDOWS :
2795
+ compiler_name = "$cxx" if IS_HIP_EXTENSION else "cl"
2756
2796
compile_rule .append (
2757
- ' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags' )
2758
- compile_rule .append (' deps = msvc' )
2797
+ f' command = { compiler_name } /showIncludes $cflags -c $in /Fo$out $post_cflags' )
2798
+ if not IS_HIP_EXTENSION :
2799
+ compile_rule .append (' deps = msvc' )
2759
2800
else :
2760
2801
compile_rule .append (
2761
2802
' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags' )
0 commit comments