Skip to content

Commit 1a5a7da

Browse files
alugoreypruthvistony
authored andcommitted
Extend CK gemm/sdpa support to gfx950 (#45)
Update CK for gfx950 (#49)
1 parent 3d6ba22 commit 1a5a7da

File tree

4 files changed

+4
-3
lines changed

4 files changed

+4
-3
lines changed

aten/src/ATen/Context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
416416
if(b == at::ROCmFABackend::Ck) {
417417
static const bool ck_unsupported = []() {
418418
static const std::vector<std::string> archs = {
419-
"gfx90a", "gfx942"
419+
"gfx90a", "gfx942", "gfx950"
420420
};
421421
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
422422
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {

aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,5 @@ struct fmha_bwd_traits
453453
bool is_deterministic;
454454
// TODO: padding check is inside this api
455455
};
456+
template <int Version = 2>
456457
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
1414
#endif
1515
__global__ void kentry_pt(Args... args)
1616
{
17-
#if (defined(__gfx90a__) || defined(__gfx942__))
17+
#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
1818
Kernel{}(args...);
1919
#else
2020
CUDA_KERNEL_ASSERT(false && "Fatal! Attempting to call a CK SDPA kernel on unsupported hardware");

third_party/composable_kernel

Submodule composable_kernel updated 836 files

0 commit comments

Comments
 (0)