diff --git a/include/flashinfer/comm/trtllm_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_allreduce_fusion.cuh index dc48372cc7..2face7212e 100644 --- a/include/flashinfer/comm/trtllm_allreduce_fusion.cuh +++ b/include/flashinfer/comm/trtllm_allreduce_fusion.cuh @@ -1364,12 +1364,16 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams const& par threads_per_block *= 2; cluster_size /= 2; } + int sm_count = get_sm_count(); + while (cluster_num * cluster_size > sm_count && cluster_size > 1 && threads_per_block <= 512) { + threads_per_block *= 2; + cluster_size /= 2; + } FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks, "not oneshot, or threads_per_block < nranks"); int block_size = threads_per_block; FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0, "block_size > 1024 or cluster_size <= 0"); - int sm_count = get_sm_count(); int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; cudaLaunchConfig_t cfg; cudaLaunchAttribute attribute[2];