Skip to content

Commit 5fd670e

Browse files
xinyazhang authored and pytorchmergebot committed
[ROCM] Properly disable Flash Attention/Efficient Attention with environment variables (pytorch#133866)
Now `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` can compile correctly. Fixes pytorch#125230. Pull Request resolved: pytorch#133866. Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily, https://github.com/malfet
1 parent 5b392d2 commit 5fd670e

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -883,6 +883,16 @@ cmake_dependent_option(
883883
Will be disabled if not supported by the platform" ON
884884
"USE_CUDA OR USE_ROCM" OFF)
885885

886+
#
887+
# Cannot be put into Dependencies.cmake due to a circular dependency:
888+
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
889+
#
890+
if(USE_ROCM)
891+
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
892+
include(cmake/External/aotriton.cmake)
893+
endif()
894+
endif()
895+
886896
if(DEBUG_CUDA)
887897
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
888898
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
#include <c10/util/string_view.h>
2626

2727
#if USE_ROCM
28+
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
2829
#include <aotriton/flash.h>
30+
#define USE_AOTRITON 1
31+
#endif
2932
#endif
3033

3134
/**
@@ -208,6 +211,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
208211
using sm80 = SMVersion<8, 0>;
209212
using sm90 = SMVersion<9, 0>;
210213
#if USE_ROCM
214+
#if USE_AOTRITON
211215
auto stream = at::cuda::getCurrentCUDAStream().stream();
212216
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
213217
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -217,6 +221,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
217221
}
218222
return false;
219223
}
224+
#else
225+
return false;
226+
#endif
220227
#else
221228
auto dprops = at::cuda::getCurrentDeviceProperties();
222229
if (!check_sm_version<sm80, sm90>(dprops)) {
@@ -239,6 +246,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
239246
using sm50 = SMVersion<5, 0>;
240247
using sm90 = SMVersion<9, 0>;
241248
#if USE_ROCM
249+
#if USE_AOTRITON
242250
auto stream = at::cuda::getCurrentCUDAStream().stream();
243251
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
244252
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -248,6 +256,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
248256
}
249257
return false;
250258
}
259+
#else
260+
return false;
261+
#endif
251262
#else
252263
auto dprops = at::cuda::getCurrentDeviceProperties();
253264
if (!check_sm_version<sm50, sm90>(dprops)) {

cmake/Dependencies.cmake

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1103,10 +1103,6 @@ if(USE_ROCM)
11031103
message(STATUS "Disabling Kernel Assert for ROCm")
11041104
endif()
11051105

1106-
include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
1107-
if(USE_CUDA)
1108-
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
1109-
endif()
11101106
else()
11111107
caffe2_update_option(USE_ROCM OFF)
11121108
endif()

0 commit comments

Comments
 (0)