@@ -349,14 +349,17 @@ ReductionParams reductionHeuristic(
       inputs_consumed_per_block_iter *= rparams.block_dim_y_;
       red_elems_per_thread = ceilDiv(red_elems_per_thread, rparams.block_dim_y_);
       rparams.cross_warp_ = true;
+      rparams.mul_reds_per_blk_ = false;
       // Do multiple reductions per block
     } else {
+      rparams.cross_warp_ = false;
       rparams.mul_reds_per_blk_ = true;
       outputs_produced_per_block_iter *= rparams.block_dim_y_;
     }
 
   // 5. Distributing work across blocks
 
+  // WARNING: Current device for codegen may not be the target device
   int device_max_threads_per_multiprocessor =
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor;
   int device_multiprocessor_count =
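Note on the hunk above: the two added assignments make cross_warp_ and mul_reds_per_blk_ explicitly mutually exclusive, rather than leaving one flag at whatever value it previously held. A minimal standalone sketch of that invariant follows; the names ending in "Sketch" and the helper pickBlockDimYUse are invented for illustration, and the real reductionHeuristic() bases its choice on more inputs than shown here.

// Hypothetical, simplified mirror of the patched branch. A block either
// reduces one output cooperatively across threadIdx.y (cross_warp_) or
// assigns each threadIdx.y row its own reduction (mul_reds_per_blk_),
// never both.
#include <cstdint>

struct ReductionParamsSketch {
  bool cross_warp_ = false; // threads along blockDim.y cooperate on one output
  bool mul_reds_per_blk_ = false; // each blockDim.y row owns its own reduction
  int64_t block_dim_y_ = 1;
};

inline int64_t ceilDivSketch(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

ReductionParamsSketch pickBlockDimYUse(
    int64_t& red_elems_per_thread, int64_t block_dim_y) {
  ReductionParamsSketch rparams;
  rparams.block_dim_y_ = block_dim_y;
  if (red_elems_per_thread > 1) {
    // Spend blockDim.y on the reduction axis.
    red_elems_per_thread = ceilDivSketch(red_elems_per_thread, block_dim_y);
    rparams.cross_warp_ = true;
    rparams.mul_reds_per_blk_ = false; // added by this hunk: flags stay exclusive
  } else {
    // Spend blockDim.y on independent reductions.
    rparams.cross_warp_ = false; // added by this hunk: flags stay exclusive
    rparams.mul_reds_per_blk_ = true;
  }
  return rparams;
}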
@@ -402,10 +405,13 @@ bool scheduleReduction(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
   // 2D at this point to make the issue easier, right now.
 
   // Find Reduction TensorView
+  // TODO: This is making an assumption there is only one reduction
+  // in a kernel. This will not be true in the long run.
   TensorView* red_tv = nullptr;
   for (auto& expr : fusion->exprs(/*from_outputs_only*/ true)) {
     if (expr->type() == ExprType::ReductionOp) {
       red_tv = static_cast<TensorView*>(expr->output(0));
+      break;
     }
   }
   if (red_tv == nullptr) { // No reduction found
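Note on the hunk above: the added break, together with the TODO, makes the single-reduction assumption explicit; without it, the loop silently kept the last ReductionOp it encountered. A hedged sketch follows of how the lookup could instead reject fusions with more than one reduction. The helper findSingleReductionOutput is hypothetical and not part of this PR; it reuses the Fusion/TensorView/ExprType types visible in the hunk.

// Hypothetical helper: returns the lone reduction output, or nullptr if
// the fusion has zero or more than one ReductionOp. A sketch only; callers
// would still need to distinguish "none" from "too many" if that matters.
TensorView* findSingleReductionOutput(Fusion* fusion) {
  TensorView* red_tv = nullptr;
  for (auto& expr : fusion->exprs(/*from_outputs_only*/ true)) {
    if (expr->type() != ExprType::ReductionOp) {
      continue;
    }
    if (red_tv != nullptr) {
      // Second reduction found; the current heuristic cannot schedule this.
      return nullptr;
    }
    red_tv = static_cast<TensorView*>(expr->output(0));
  }
  return red_tv;
}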