Skip to content

Commit 17fbbf9

Browse files
authored
[contrib] Support optimizers on rocm. (pytorch#33)
* Enable the deprecated fused Adam optimizer * Enable the deprecated fused LAMB optimizer * Reset the compiler arguments * Fix a syntax error * Align the compiler arguments
1 parent d2f6d04 commit 17fbbf9

File tree

1 file changed

+50
-19
lines changed

1 file changed

+50
-19
lines changed

setup.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
8787
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
8888
"You can try commenting out this check (at your own risk).")
8989

90+
def check_if_rocm_pytorch():
    """Return True when the installed PyTorch is a ROCm (HIP) build.

    A ROCm build is detected only for torch >= 1.5, and only when both
    ``torch.version.hip`` and ``ROCM_HOME`` are set.

    Returns:
        bool: True if running under a ROCm PyTorch build, else False.
    """
    is_rocm_pytorch = False
    # Compare (major, minor) numerically: a plain string comparison such as
    # torch.__version__ >= '1.5' is lexicographic and wrongly treats
    # '1.10' as older than '1.5'.
    try:
        version = tuple(int(part) for part in torch.__version__.split('.')[:2])
    except ValueError:
        # Unparseable version string (e.g. unusual dev builds): treat as too old.
        version = (0, 0)
    if version >= (1, 5):
        from torch.utils.cpp_extension import ROCM_HOME
        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)
    return is_rocm_pytorch
97+
9098
# Set up macros for forward/backward compatibility hack around
9199
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
92100
# and
@@ -279,17 +287,28 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
279287
from torch.utils.cpp_extension import BuildExtension
280288
cmdclass['build_ext'] = BuildExtension
281289

282-
if torch.utils.cpp_extension.CUDA_HOME is None:
290+
is_rocm_pytorch = check_if_rocm_pytorch()
291+
292+
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
283293
raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
284294
else:
285-
ext_modules.append(
286-
CUDAExtension(name='fused_adam_cuda',
287-
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
288-
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
289-
include_dirs=[os.path.join(this_dir, 'csrc')],
290-
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
291-
'nvcc':['-O3',
292-
'--use_fast_math'] + version_dependent_macros}))
295+
if not is_rocm_pytorch:
296+
ext_modules.append(
297+
CUDAExtension(name='fused_adam_cuda',
298+
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
299+
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
300+
include_dirs=[os.path.join(this_dir, 'csrc')],
301+
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
302+
'nvcc':['-O3',
303+
'--use_fast_math'] + version_dependent_macros}))
304+
else:
305+
print ("INFO: Building deprecated fused adam.")
306+
ext_modules.append(
307+
CUDAExtension(name='fused_adam_cuda',
308+
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
309+
'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip'],
310+
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
311+
extra_compile_args=['-O3'] + version_dependent_macros))
293312

294313
if "--deprecated_fused_lamb" in sys.argv:
295314
from torch.utils.cpp_extension import CUDAExtension
@@ -298,18 +317,30 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
298317
from torch.utils.cpp_extension import BuildExtension
299318
cmdclass['build_ext'] = BuildExtension
300319

301-
if torch.utils.cpp_extension.CUDA_HOME is None:
320+
is_rocm_pytorch = check_if_rocm_pytorch()
321+
322+
if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
302323
raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
303324
else:
304-
ext_modules.append(
305-
CUDAExtension(name='fused_lamb_cuda',
306-
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
307-
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
308-
'csrc/multi_tensor_l2norm_kernel.cu'],
309-
include_dirs=[os.path.join(this_dir, 'csrc')],
310-
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
311-
'nvcc':['-O3',
312-
'--use_fast_math'] + version_dependent_macros}))
325+
if not is_rocm_pytorch:
326+
ext_modules.append(
327+
CUDAExtension(name='fused_lamb_cuda',
328+
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
329+
'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
330+
'csrc/multi_tensor_l2norm_kernel.cu'],
331+
include_dirs=[os.path.join(this_dir, 'csrc')],
332+
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
333+
'nvcc':['-O3',
334+
'--use_fast_math'] + version_dependent_macros}))
335+
else:
336+
print ("INFO: Building deprecated fused lamb.")
337+
ext_modules.append(
338+
CUDAExtension(name='fused_lamb_cuda',
339+
sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
340+
'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip',
341+
'csrc/hip/multi_tensor_l2norm_kernel.hip'],
342+
include_dirs=[os.path.join(this_dir, 'csrc/hip')],
343+
extra_compile_args=['-O3'] + version_dependent_macros))
313344

314345
# Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
315346
generator_flag = []

0 commit comments

Comments
 (0)