diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 5f9209979341..5766fbab4e87 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -7,7 +7,11 @@ #include #include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#ifdef USE_ROCM + #include // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK +#else + #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#endif #ifndef USE_ROCM #include @@ -320,8 +324,13 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { dim3 grid(params.batch, params.dim / kNRows); auto kernel = &selective_scan_fwd_kernel; if (kSmemSize >= 48 * 1024) { +#ifdef USE_ROCM + C10_HIP_CHECK(hipFuncSetAttribute( + reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); +#else C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); +#endif } kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK();