Skip to content

Commit 3b9ecc1

Browse files
committed
update allreduce to match trtllm
1 parent 20ab8ab commit 3b9ecc1

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

include/flashinfer/comm/trtllm_allreduce_fusion.cuh

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1364,12 +1364,16 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par
1364 1364
threads_per_block *= 2;
1365 1365
cluster_size /= 2;
1366 1366
}
1367+
int sm_count = get_sm_count();
1368+
while (cluster_num * cluster_size > sm_count && cluster_size > 1 && threads_per_block <= 512) {
1369+
threads_per_block *= 2;
1370+
cluster_size /= 2;
1371+
}
1367 1372
FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks,
1368 1373
"not oneshot, or threads_per_block < nranks");
1369 1374
int block_size = threads_per_block;
1370 1375
FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0,
1371 1376
"block_size > 1024 or cluster_size <= 0");
1372-
int sm_count = get_sm_count();
1373 1377
int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
1374 1378
cudaLaunchConfig_t cfg;
1375 1379
cudaLaunchAttribute attribute[2];

0 commit comments

Comments
 (0)