@@ -173,9 +173,6 @@ def _join_rocm_home(*paths) -> str:
173
173
if ROCM_HOME is None :
174
174
raise OSError ('ROCM_HOME environment variable is not set. '
175
175
'Please set it to your ROCm install root.' )
176
- elif IS_WINDOWS :
177
- raise OSError ('Building PyTorch extensions using '
178
- 'ROCm and Windows is not supported.' )
179
176
return os .path .join (ROCM_HOME , * paths )
180
177
181
178
def _join_sycl_home (* paths ) -> str :
@@ -270,12 +267,14 @@ def _join_sycl_home(*paths) -> str:
270
267
]
271
268
272
269
COMMON_HIP_FLAGS = [
273
- '-fPIC' ,
274
270
'-D__HIP_PLATFORM_AMD__=1' ,
275
271
'-DUSE_ROCM=1' ,
276
272
'-DHIPBLAS_V2' ,
277
273
]
278
274
275
+ if not IS_WINDOWS :
276
+ COMMON_HIP_FLAGS .append ('-fPIC' )
277
+
279
278
COMMON_HIPCC_FLAGS = [
280
279
'-DCUDA_HAS_FP16=1' ,
281
280
'-D__HIP_NO_HALF_OPERATORS__=1' ,
@@ -513,6 +512,19 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N
513
512
f'Please make sure to use an adequate version of { compiler_name } ({ version_bound_str } ).'
514
513
)
515
514
515
+ # Specify Visual Studio C runtime library for hipcc
516
+ def _set_hipcc_runtime_lib (is_standalone ):
517
+ debug = os .getenv ("DEBUG" )
518
+ if is_standalone :
519
+ if debug :
520
+ COMMON_HIP_FLAGS .append ('-fms-runtime-lib=static_dbg' )
521
+ else :
522
+ COMMON_HIP_FLAGS .append ('-fms-runtime-lib=static' )
523
+ else :
524
+ if debug :
525
+ COMMON_HIP_FLAGS .append ('-fms-runtime-lib=dll_dbg' )
526
+ else :
527
+ COMMON_HIP_FLAGS .append ('-fms-runtime-lib=dll' )
516
528
517
529
def _append_sycl_std_if_no_std_present (cflags ):
518
530
if not any (flag .startswith ('-sycl-std=' ) for flag in cflags ):
@@ -833,6 +845,9 @@ def unix_wrap_ninja_compile(sources,
833
845
def win_cuda_flags (cflags ):
834
846
return (COMMON_NVCC_FLAGS +
835
847
cflags + _get_cuda_arch_flags (cflags ))
848
+
849
+ def win_hip_flags (cflags ):
850
+ return (COMMON_HIPCC_FLAGS + COMMON_HIP_FLAGS + cflags + _get_rocm_arch_flags (cflags ))
836
851
837
852
def win_wrap_single_compile (sources ,
838
853
output_dir = None ,
@@ -870,19 +885,25 @@ def spawn(cmd):
870
885
src = src_list [0 ]
871
886
obj = obj_list [0 ]
872
887
if _is_cuda_file (src ):
873
- nvcc = _join_cuda_home ('bin' , 'nvcc' )
888
+ if IS_HIP_EXTENSION :
889
+ nvcc = _get_hipcc_path ()
890
+ else :
891
+ nvcc = _join_cuda_home ('bin' , 'nvcc' )
874
892
if isinstance (self .cflags , dict ):
875
893
cflags = self .cflags ['nvcc' ]
876
894
elif isinstance (self .cflags , list ):
877
895
cflags = self .cflags
878
896
else :
879
897
cflags = []
880
898
881
- cflags = win_cuda_flags (cflags ) + ['-std=c++17' , '--use-local-env' ]
899
+ if IS_HIP_EXTENSION :
900
+ cflags = win_hip_flags (cflags )
901
+ else :
902
+ cflags = win_cuda_flags (cflags ) + ['-std=c++17' , '--use-local-env' ]
903
+ for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
904
+ cflags = ['-Xcudafe' , '--diag_suppress=' + ignore_warning ] + cflags
882
905
for flag in COMMON_MSVC_FLAGS :
883
906
cflags = ['-Xcompiler' , flag ] + cflags
884
- for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
885
- cflags = ['-Xcudafe' , '--diag_suppress=' + ignore_warning ] + cflags
886
907
cmd = [nvcc , '-c' , src , '-o' , obj ] + include_list + cflags
887
908
elif isinstance (self .cflags , dict ):
888
909
cflags = COMMON_MSVC_FLAGS + self .cflags ['cxx' ]
@@ -911,7 +932,6 @@ def win_wrap_ninja_compile(sources,
911
932
extra_preargs = None ,
912
933
extra_postargs = None ,
913
934
depends = None ):
914
-
915
935
if not self .compiler .initialized :
916
936
self .compiler .initialize ()
917
937
output_dir = os .path .abspath (output_dir )
@@ -928,14 +948,20 @@ def win_wrap_ninja_compile(sources,
928
948
self .compiler ._setup_compile (output_dir , macros ,
929
949
include_dirs , sources ,
930
950
depends , extra_postargs )
951
+ # Replace space with \ when using hipcc (hipcc passes includes to clang without ""s so clang sees space in include paths as new argument)
952
+ if IS_HIP_EXTENSION :
953
+ pp_opts = ["-I{}" .format (s [2 :].replace (" " , "\\ " )) if s .startswith ('-I' ) else s for s in pp_opts ]
931
954
common_cflags = extra_preargs or []
932
955
cflags = []
933
956
if debug :
934
957
cflags .extend (self .compiler .compile_options_debug )
935
958
else :
936
959
cflags .extend (self .compiler .compile_options )
937
- common_cflags .extend (COMMON_MSVC_FLAGS )
938
- cflags = cflags + common_cflags + pp_opts
960
+ cflags = cflags + common_cflags + pp_opts + COMMON_MSVC_FLAGS
961
+ if IS_HIP_EXTENSION :
962
+ common_cflags .extend (COMMON_HIP_FLAGS )
963
+ else :
964
+ common_cflags .extend (COMMON_MSVC_FLAGS )
939
965
with_cuda = any (map (_is_cuda_file , sources ))
940
966
941
967
# extra_postargs can be either:
@@ -945,25 +971,31 @@ def win_wrap_ninja_compile(sources,
945
971
post_cflags = extra_postargs ['cxx' ]
946
972
else :
947
973
post_cflags = list (extra_postargs )
974
+ if IS_HIP_EXTENSION :
975
+ post_cflags = COMMON_HIP_FLAGS + post_cflags
948
976
append_std17_if_no_std_present (post_cflags )
949
977
950
978
cuda_post_cflags = None
951
979
cuda_cflags = None
952
980
if with_cuda :
953
- cuda_cflags = ['-std=c++17' , '--use-local-env' ]
981
+ cuda_cflags = ['-std=c++17' ]
954
982
for common_cflag in common_cflags :
955
983
cuda_cflags .append ('-Xcompiler' )
956
984
cuda_cflags .append (common_cflag )
957
- for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
958
- cuda_cflags .append ('-Xcudafe' )
959
- cuda_cflags .append ('--diag_suppress=' + ignore_warning )
985
+ if not IS_HIP_EXTENSION :
986
+ cuda_cflags .append ('--use-local-env' )
987
+ for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS :
988
+ cuda_cflags .append ('-Xcudafe' )
989
+ cuda_cflags .append ('--diag_suppress=' + ignore_warning )
960
990
cuda_cflags .extend (pp_opts )
961
991
if isinstance (extra_postargs , dict ):
962
992
cuda_post_cflags = extra_postargs ['nvcc' ]
963
993
else :
964
994
cuda_post_cflags = list (extra_postargs )
965
- cuda_post_cflags = win_cuda_flags (cuda_post_cflags )
966
-
995
+ if IS_HIP_EXTENSION :
996
+ cuda_post_cflags = win_hip_flags (cuda_post_cflags )
997
+ else :
998
+ cuda_post_cflags = win_cuda_flags (cuda_post_cflags )
967
999
cflags = _nt_quote_args (cflags )
968
1000
post_cflags = _nt_quote_args (post_cflags )
969
1001
if with_cuda :
@@ -992,7 +1024,6 @@ def win_wrap_ninja_compile(sources,
992
1024
993
1025
# Return *all* object filenames, not just the ones we just built.
994
1026
return objects
995
-
996
1027
# Monkey-patch the _compile or compile method.
997
1028
# https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
998
1029
if self .compiler .compiler_type == 'msvc' :
@@ -2070,7 +2101,12 @@ def _jit_compile(name,
2070
2101
2071
2102
return _import_module_from_library (name , build_directory , is_python_module )
2072
2103
2073
-
2104
+ def _get_hipcc_path ():
2105
+ if IS_WINDOWS :
2106
+ return _join_rocm_home ('bin' , 'hipcc.bat' )
2107
+ else :
2108
+ return _join_rocm_home ('bin' , 'hipcc' )
2109
+
2074
2110
def _write_ninja_file_and_compile_objects (
2075
2111
sources : list [str ],
2076
2112
objects ,
@@ -2473,6 +2509,7 @@ def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) ->
2473
2509
stdout_fileno = 1
2474
2510
subprocess .run (
2475
2511
command ,
2512
+ shell = IS_WINDOWS and IS_HIP_EXTENSION ,
2476
2513
stdout = stdout_fileno if verbose else subprocess .PIPE ,
2477
2514
stderr = subprocess .STDOUT ,
2478
2515
cwd = build_directory ,
@@ -2567,7 +2604,8 @@ def _write_ninja_file_to_build_library(path,
2567
2604
common_cflags += [f"{ x } " for x in _get_glibcxx_abi_build_flags ()]
2568
2605
2569
2606
if IS_WINDOWS :
2570
- cflags = common_cflags + COMMON_MSVC_FLAGS + ['/std:c++17' ] + extra_cflags
2607
+ cflags = common_cflags + ['/std:c++17' ] + extra_cflags
2608
+ clfags += COMMON_HIP_FLAGS if IS_HIP_EXTENSION else COMMON_MSVC_FLAGS
2571
2609
cflags = _nt_quote_args (cflags )
2572
2610
else :
2573
2611
cflags = common_cflags + ['-fPIC' , '-std=c++17' ] + extra_cflags
@@ -2626,6 +2664,7 @@ def object_file_path(source_file: str) -> str:
2626
2664
2627
2665
objects = [object_file_path (src ) for src in sources ]
2628
2666
ldflags = ([] if is_standalone else [SHARED_FLAG ]) + extra_ldflags
2667
+ _set_hipcc_runtime_lib (is_standalone )
2629
2668
2630
2669
# The darwin linker needs explicit consent to ignore unresolved symbols.
2631
2670
if IS_MACOS :
@@ -2718,7 +2757,7 @@ def sanitize_flags(flags):
2718
2757
nvcc = os .getenv ("PYTORCH_NVCC" ) # user can set nvcc compiler with ccache using the environment variable here
2719
2758
else :
2720
2759
if IS_HIP_EXTENSION :
2721
- nvcc = _join_rocm_home ( 'bin' , 'hipcc' )
2760
+ nvcc = _get_hipcc_path ( )
2722
2761
else :
2723
2762
nvcc = _join_cuda_home ('bin' , 'nvcc' )
2724
2763
config .append (f'nvcc = { nvcc } ' )
@@ -2747,9 +2786,11 @@ def sanitize_flags(flags):
2747
2786
# See https://ninja-build.org/build.ninja.html for reference.
2748
2787
compile_rule = ['rule compile' ]
2749
2788
if IS_WINDOWS :
2789
+ compiler_name = "$cxx" if IS_HIP_EXTENSION else "cl"
2750
2790
compile_rule .append (
2751
- ' command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags' )
2752
- compile_rule .append (' deps = msvc' )
2791
+ f' command = { compiler_name } /showIncludes $cflags -c $in /Fo$out $post_cflags' )
2792
+ if not IS_HIP_EXTENSION :
2793
+ compile_rule .append (' deps = msvc' )
2753
2794
else :
2754
2795
compile_rule .append (
2755
2796
' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags' )
0 commit comments