@@ -87,6 +87,14 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
87
87
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
88
88
"You can try commenting out this check (at your own risk)." )
89
89
90
def _torch_version_at_least(version, major, minor):
    """Return True when ``version`` is numerically >= ``major.minor``.

    Only the leading ``<major>[.<minor>]`` digits are inspected, so suffixes
    such as ``.0a0+git1234`` or ``+rocm4.2`` are ignored.  A plain string
    comparison is NOT used because it is lexicographic: ``'1.10' >= '1.5'``
    evaluates to False.

    Args:
        version: Version string (or str-like object) to inspect.
        major: Minimum required major version, as an int.
        minor: Minimum required minor version, as an int.

    Returns:
        bool: True if the parsed version is at least ``major.minor``;
        False if it is older or if no leading number can be parsed.
    """
    import re  # local import: keeps this block self-contained

    parsed = re.match(r"\s*(\d+)(?:\.(\d+))?", str(version))
    if parsed is None:
        return False
    found_major = int(parsed.group(1))
    found_minor = int(parsed.group(2) or 0)
    return (found_major, found_minor) >= (major, minor)


def check_if_rocm_pytorch():
    """Report whether the installed PyTorch is a ROCm (HIP) build.

    The check only runs for PyTorch >= 1.5, because ``ROCM_HOME`` is imported
    from ``torch.utils.cpp_extension`` solely behind that version gate.

    Returns:
        bool: True when ``torch.version.hip`` is set and ``ROCM_HOME`` is not
        None; False otherwise (including for PyTorch older than 1.5).
    """
    if not _torch_version_at_least(torch.__version__, 1, 5):
        return False

    from torch.utils.cpp_extension import ROCM_HOME

    # getattr tolerates builds whose torch.version module has no `hip` field.
    hip_version = getattr(torch.version, "hip", None)
    return hip_version is not None and ROCM_HOME is not None
97
+
90
98
# Set up macros for forward/backward compatibility hack around
91
99
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
92
100
# and
@@ -279,17 +287,28 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
279
287
from torch .utils .cpp_extension import BuildExtension
280
288
cmdclass ['build_ext' ] = BuildExtension
281
289
282
- if torch .utils .cpp_extension .CUDA_HOME is None :
290
+ is_rocm_pytorch = check_if_rocm_pytorch ()
291
+
292
+ if torch .utils .cpp_extension .CUDA_HOME is None and (not is_rocm_pytorch ):
283
293
raise RuntimeError ("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc." )
284
294
else :
285
- ext_modules .append (
286
- CUDAExtension (name = 'fused_adam_cuda' ,
287
- sources = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp' ,
288
- 'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu' ],
289
- include_dirs = [os .path .join (this_dir , 'csrc' )],
290
- extra_compile_args = {'cxx' : ['-O3' ,] + version_dependent_macros ,
291
- 'nvcc' :['-O3' ,
292
- '--use_fast_math' ] + version_dependent_macros }))
295
+ if not is_rocm_pytorch :
296
+ ext_modules .append (
297
+ CUDAExtension (name = 'fused_adam_cuda' ,
298
+ sources = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp' ,
299
+ 'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu' ],
300
+ include_dirs = [os .path .join (this_dir , 'csrc' )],
301
+ extra_compile_args = {'cxx' : ['-O3' ,] + version_dependent_macros ,
302
+ 'nvcc' :['-O3' ,
303
+ '--use_fast_math' ] + version_dependent_macros }))
304
+ else :
305
+ print ("INFO: Building deprecated fused adam." )
306
+ ext_modules .append (
307
+ CUDAExtension (name = 'fused_adam_cuda' ,
308
+ sources = ['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp' ,
309
+ 'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip' ],
310
+ include_dirs = [os .path .join (this_dir , 'csrc/hip' )],
311
+ extra_compile_args = ['-O3' ] + version_dependent_macros ))
293
312
294
313
if "--deprecated_fused_lamb" in sys .argv :
295
314
from torch .utils .cpp_extension import CUDAExtension
@@ -298,18 +317,30 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
298
317
from torch .utils .cpp_extension import BuildExtension
299
318
cmdclass ['build_ext' ] = BuildExtension
300
319
301
- if torch .utils .cpp_extension .CUDA_HOME is None :
320
+ is_rocm_pytorch = check_if_rocm_pytorch ()
321
+
322
+ if torch .utils .cpp_extension .CUDA_HOME is None and (not is_rocm_pytorch ):
302
323
raise RuntimeError ("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc." )
303
324
else :
304
- ext_modules .append (
305
- CUDAExtension (name = 'fused_lamb_cuda' ,
306
- sources = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp' ,
307
- 'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu' ,
308
- 'csrc/multi_tensor_l2norm_kernel.cu' ],
309
- include_dirs = [os .path .join (this_dir , 'csrc' )],
310
- extra_compile_args = {'cxx' : ['-O3' ,] + version_dependent_macros ,
311
- 'nvcc' :['-O3' ,
312
- '--use_fast_math' ] + version_dependent_macros }))
325
+ if not is_rocm_pytorch :
326
+ ext_modules .append (
327
+ CUDAExtension (name = 'fused_lamb_cuda' ,
328
+ sources = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp' ,
329
+ 'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu' ,
330
+ 'csrc/multi_tensor_l2norm_kernel.cu' ],
331
+ include_dirs = [os .path .join (this_dir , 'csrc' )],
332
+ extra_compile_args = {'cxx' : ['-O3' ,] + version_dependent_macros ,
333
+ 'nvcc' :['-O3' ,
334
+ '--use_fast_math' ] + version_dependent_macros }))
335
+ else :
336
+ print ("INFO: Building deprecated fused lamb." )
337
+ ext_modules .append (
338
+ CUDAExtension (name = 'fused_lamb_cuda' ,
339
+ sources = ['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp' ,
340
+ 'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip' ,
341
+ 'csrc/hip/multi_tensor_l2norm_kernel.hip' ],
342
+ include_dirs = [os .path .join (this_dir , 'csrc/hip' )],
343
+ extra_compile_args = ['-O3' ] + version_dependent_macros ))
313
344
314
345
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
315
346
generator_flag = []
0 commit comments