Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#ifdef USE_ROCM
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
#else
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#endif

#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
Expand Down Expand Up @@ -320,8 +324,13 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
#ifdef USE_ROCM
C10_HIP_CHECK(hipFuncSetAttribute(
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#endif
Comment on lines +327 to +333
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

While you've correctly handled `cudaFuncSetAttribute` with a conditional for ROCm, the `C10_CUDA_KERNEL_LAUNCH_CHECK()` call on line 336 is still CUDA-specific and will cause a compilation error on ROCm builds. It should also be wrapped in an `#ifdef USE_ROCM` block so that ROCm builds call `C10_HIP_KERNEL_LAUNCH_CHECK()` instead.

}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand Down