Skip to content

Commit cc253d9

Browse files
authored
Merge pull request #63 from Jorghi12/fix_cmake
Fix the PyTorch HIP Build Configuration
2 parents 801c1cc + f72c013 commit cc253d9

File tree

5 files changed

+34
-11
lines changed

5 files changed

+34
-11
lines changed

.jenkins/pytorch/build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
4141

4242
# This environment variable enabled HCC Optimizations that speed up the linking stage.
4343
# https://github.com/RadeonOpenCompute/hcc#hcc-with-thinlto-linking
44-
# export KMTHINLTO=1
44+
export KMTHINLTO=1
4545

4646
sudo chown -R jenkins:jenkins /usr/local
4747
rm -rf "$(dirname "${BASH_SOURCE[0]}")/../../../pytorch_amd/" || true

caffe2/CMakeLists.txt

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,24 @@ endif()
261261
# ---[ Caffe2 HIP sources.
262262
if(USE_ROCM)
263263
# Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs.
264-
IF(BUILD_ATEN)
265-
HIP_INCLUDE_DIRECTORIES(${Caffe2_HIP_INCLUDES})
266-
ENDIF()
264+
if(BUILD_ATEN)
265+
# Get Compile Definitions from the directory (FindHIP.CMake bug)
266+
get_directory_property(MY_DEFINITIONS COMPILE_DEFINITIONS)
267+
if(MY_DEFINITIONS)
268+
foreach(_item ${MY_DEFINITIONS})
269+
LIST(APPEND HIP_HCC_FLAGS "-D${_item}")
270+
endforeach()
271+
endif()
272+
273+
# Call again since Caffe2_HIP_INCLUDES is extended with ATen include dirs.
274+
hip_include_directories(${Caffe2_HIP_INCLUDES})
275+
endif()
267276
IF(BUILD_CAFFE2)
268277
set_source_files_properties(${Caffe2_HIP_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
269278
ENDIF()
270-
hip_add_library(caffe2_hip ${Caffe2_HIP_SRCS})
279+
280+
# FindHIP.CMake checks if the SHARED flag is set and adds extra logic accordingly.
281+
hip_add_library(caffe2_hip SHARED ${Caffe2_HIP_SRCS})
271282

272283
# Since PyTorch files contain HIP headers, these flags are required for the necessary definitions to be added.
273284
set_target_properties(caffe2_hip PROPERTIES COMPILE_FLAGS ${HIP_HIPCC_FLAGS})

cmake/public/LoadHIP.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ IF(HIP_FOUND)
111111
set(CMAKE_HIP_ARCHIVE_CREATE ${CMAKE_CXX_ARCHIVE_CREATE})
112112
set(CMAKE_HIP_ARCHIVE_APPEND ${CMAKE_CXX_ARCHIVE_APPEND})
113113
set(CMAKE_HIP_ARCHIVE_FINISH ${CMAKE_CXX_ARCHIVE_FINISH})
114+
SET(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
115+
SET(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
114116
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###
115117

116118
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)

tools/amd_build/build_pytorch_amd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
"""Requires the hipify-python.py script (https://github.com/ROCm-Developer-Tools/pyHIPIFY)."""
21
import shutil
32
import subprocess
43
import os
@@ -8,6 +7,7 @@
87

98
amd_build_dir = os.path.dirname(os.path.realpath(__file__))
109
proj_dir = os.path.dirname(os.path.dirname(amd_build_dir))
10+
1111
includes = [
1212
"aten/*",
1313
"torch/*"
@@ -16,7 +16,7 @@
1616
# List of operators currently disabled
1717
yaml_file = os.path.join(amd_build_dir, "disabled_features.yaml")
1818

19-
# Apply patch files.
19+
# Apply patch files in place.
2020
patch_folder = os.path.join(amd_build_dir, "patches")
2121
for filename in os.listdir(os.path.join(amd_build_dir, "patches")):
2222
subprocess.Popen(["git", "apply", os.path.join(patch_folder, filename)], cwd=proj_dir)

tools/setup_helpers/rocm.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1-
from .env import check_env_flag
2-
# Check if ROCM is enabled
3-
USE_ROCM = check_env_flag('USE_ROCM')
4-
ROCM_HOME = "/opt/rocm"
1+
import os
2+
from .env import check_env_flag, check_negative_env_flag
3+
4+
# Get ROCm Home Path
5+
ROCM_HOME = os.getenv("ROCM_HOME", "/opt/rocm")
56
ROCM_VERSION = ""
7+
USE_ROCM = False
8+
9+
# Check if ROCm disabled.
10+
if check_negative_env_flag("USE_ROCM"):
11+
USE_ROCM = False
12+
else:
13+
# If ROCM home exists or we explicitly enable ROCm
14+
if os.path.exists(ROCM_HOME) or check_env_flag('USE_ROCM'):
15+
USE_ROCM = True

0 commit comments

Comments
 (0)