Skip to content

Commit 9f72899

Browse files
alugorey authored and dnikolaev-amd committed
Extend CK gemm/sdpa support to gfx950 (#45)
(cherry picked from commit e17274f35859c5097132fa096389623e4af89e26)
1 parent 5a7dfff commit 9f72899

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

aten/src/ATen/Context.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -416,7 +416,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
416416
if(b == at::ROCmFABackend::Ck) {
417417
static const bool ck_unsupported = []() {
418418
static const std::vector<std::string> archs = {
419-
"gfx90a", "gfx942"
419+
"gfx90a", "gfx942", "gfx950"
420420
};
421421
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
422422
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {

aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
1414
#endif
1515
__global__ void kentry_pt(Args... args)
1616
{
17-
#if (defined(__gfx90a__) || defined(__gfx942__))
17+
#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
1818
Kernel{}(args...);
1919
#else
2020
CUDA_KERNEL_ASSERT(false && "Fatal! Attempting to call a CK SDPA kernel on unsupported hardware");

0 commit comments

Comments
 (0)