File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -1364,12 +1364,16 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par
1364
1364
threads_per_block *= 2 ;
1365
1365
cluster_size /= 2 ;
1366
1366
}
1367
+ int sm_count = get_sm_count ();
1368
+ while (cluster_num * cluster_size > sm_count && cluster_size > 1 && threads_per_block <= 512 ) {
1369
+ threads_per_block *= 2 ;
1370
+ cluster_size /= 2 ;
1371
+ }
1367
1372
FLASHINFER_CHECK (oneshot || threads_per_block >= params.nranks ,
1368
1373
" not oneshot, or threads_per_block < nranks" );
1369
1374
int block_size = threads_per_block;
1370
1375
FLASHINFER_CHECK (block_size <= 1024 && cluster_size > 0 ,
1371
1376
" block_size > 1024 or cluster_size <= 0" );
1372
- int sm_count = get_sm_count ();
1373
1377
int grid_size = (std::min (sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
1374
1378
cudaLaunchConfig_t cfg;
1375
1379
cudaLaunchAttribute attribute[2 ];
You can’t perform that action at this time.
0 commit comments