diff --git a/torch/csrc/jit/codegen/cuda/scheduler.cpp b/torch/csrc/jit/codegen/cuda/scheduler.cpp
index 0cfa69dacb16..89266c306638 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler.cpp
@@ -295,11 +295,11 @@ ReductionParams reductionHeuristic(
       red_elems_per_thread >= kMaxValuesPerThread || !rparams.fastest_dim_) {
     inputs_consumed_per_block_iter *= rparams.block_dim_y_;
     red_elems_per_thread = ceilDiv(red_elems_per_thread, rparams.block_dim_y_);
-    rparams.cross_warp_ = true;
+    rparams.cross_block_ = true;
     rparams.mul_reds_per_blk_ = false;
     // Do multiple reductions per block
   } else {
-    rparams.cross_warp_ = false;
+    rparams.cross_block_ = false;
     rparams.mul_reds_per_blk_ = true;
     outputs_produced_per_block_iter *= rparams.block_dim_y_;
   }
@@ -320,7 +320,7 @@ ReductionParams reductionHeuristic(

   rparams.grid_dim_x_ = ceilDiv(red_outputs, outputs_produced_per_block_iter);
   // Cross-block reductions (if necessary)
-  if (rparams.cross_warp_ && red_elems_per_thread >= kMaxValuesPerThread &&
+  if (rparams.cross_block_ && red_elems_per_thread >= kMaxValuesPerThread &&
       rparams.grid_dim_x_ <= target_grid_size) {
     int blks_per_out_1 = ceilDiv(target_grid_size, rparams.grid_dim_x_);
     int blks_per_out_2 = ceilDiv(red_elems_per_thread, kMinValuesPerThread);
@@ -331,7 +331,7 @@ ReductionParams reductionHeuristic(
     rparams.grid_dim_y_ = std::max(1, blks_per_output);
     // If a cross-block reduction was generated
     if (blks_per_output > 1) {
-      rparams.cross_block_ = true;
+      rparams.cross_grid_ = true;
     }
   }

@@ -343,8 +343,8 @@ ReductionParams reductionHeuristic(
         << " Red On Fastest Dim? " << red_on_fastest_dim << std::endl
         << "Reduction Characteristics:" << std::endl
         << "\tMultiple Reds Per Block? " << rparams.mul_reds_per_blk_
-        << " Cross Warp? " << rparams.cross_warp_ << " Cross Block? "
-        << rparams.cross_block_ << std::endl
+        << " Cross Block? " << rparams.cross_block_ << " Cross Grid? "
+        << rparams.cross_grid_ << std::endl
         << "Recommended Blocking:" << std::endl
         << "\tGridX: " << rparams.grid_dim_x_
         << " GridY: " << rparams.grid_dim_y_
@@ -415,28 +415,66 @@ c10::optional<ReductionParams> scheduleReduction(
   ReductionParams rparams = reductionHeuristic(
       red_elems.value(), red_outputs.value(), red_on_fastest_dim);

-  // Heuristic Definition
+  constexpr int kLoopUnrollSplit = 4;
+
+  // Scheduling the Reduction
   if (rparams.fastest_dim_) {
     // Do multiple reductions per block
     if (rparams.mul_reds_per_blk_) {
-      // Unroll a certain number of rFactored elements
-      red_tv->split(1, 4);
+      // Reduction Splits
+      //      [outputs, |rF-Leftover, rF-Unroll, X-Warp|]
+      // Idx:     0     |    1(-3)      2(-2)    3(-1) |
+      //                --------------------------------
+      //                Reduction Dimensions
       red_tv->split(1, rparams.block_dim_x_);
-      // Unroll a certain number of rFactored elements
-      // Split Grid dimension to get multiple reds per block
+      red_tv->split(1, kLoopUnrollSplit);
+
+      // Reordering the Unroll dimension eases applying computeAt()
+      // for preceding operations and the rFactored Tensor.
+      //                            |- Reordered -|
+      //                            V             V
+      //      [outputs, |rF-Leftover, X-Warp, rF-Unroll|]
+      // Idx:     0     |    1(-3)     2(-2)    3(-1)  |
+      //                --------------------------------
+      //                Reduction Dimensions
+      red_tv->reorder({{-1, -2}, {-2, -1}});
+
+      // Output Splits
+      //      [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+      // Idx:  |      0             1     |   2(-2) -- 3(-1)
+      //       ------------------------------
+      //       Output Dimensions
       red_tv->split(0, rparams.block_dim_y_);
       auto red_tv_rf = red_tv->rFactor({-3, -1});
-      red_tv_rf->computeAt(red_tv, 1);
+
+      // WARNING: computeAt will coalesce the rFactored dimensions
+      // rFactored Reduction Tensor after computeAt():
+      //      [<Out Dims>, |X-Warp, rF-Leftover, rF-Unroll|]
+      // Idx:    0 --  1   | 2(-3)     3(-2)       4(-1)  |
+      //                   --------------------------------
+      //                   Reduction Dimensions
+      red_tv_rf->computeAt(red_tv, -1);
+
+      // After the Reduction Tensor has rFactoring applied
+      // Reduction Output Tensor:
+      //      [Out-Leftover, Out-PerBlock, X-Warp]
+      // Idx:       0             1         2(-1)
+
+      red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
       red_tv->axis(0)->parallelize(ParallelType::BIDx);
       red_tv->axis(1)->parallelize(ParallelType::TIDy);
       red_tv->axis(-1)->parallelize(ParallelType::TIDx);
-      red_tv_rf->axis(1)->parallelize(ParallelType::TIDy);
-      red_tv_rf->axis(-2)->parallelize(ParallelType::TIDx);
-      red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
-
+      // Bind Inputs to Reduction
+      // The computeAt is not to the innermost dimension of the rFactored
+      // tensor in order to force the creation of separate loop nests to cause
+      // Inputs to be separately read in their own loop.
+      //                                   computeAt(-2)------|
+      //                                                      V
+      //      [<Out Dims>, X-Warp, rF-Leftover,| rF-Unroll]
+      // Idx:    0 --  1    2(-3)     3(-2)        4(-1)
       Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
       if (!fusion->hasInput(input)) {
         input->as<TensorView>()->computeAt(red_tv_rf, -2);
@@ -444,52 +482,112 @@ c10::optional<ReductionParams> scheduleReduction(
       }
       // Do a cross-warp reduction per block
     } else {
-      if (rparams.cross_block_) {
-        red_tv->split(1, 4);
+      if (rparams.cross_grid_) {
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, rF-Unroll, X-Block, X-Grid, X-Warp|]
+        // Idx:     0     |    1(-5)      2(-4)     3(-3)    4(-2)   5(-1) |
+        //                -------------------------------------------------
+        //                Reduction Dimensions
         red_tv->split(1, rparams.block_dim_x_);
-        // Split up rFactor to reduce across warps
         red_tv->split(1, rparams.grid_dim_y_);
         red_tv->split(1, rparams.block_dim_y_);
+        red_tv->split(1, kLoopUnrollSplit);
+
+        // Reordering the Unroll dimension eases applying computeAt()
+        // for preceding operations and the rFactored Tensor.
+        //                                   |------ Reordered --------|
+        //                                   V                         V
+        //      [outputs, |rF-Leftover, X-Warp, X-Block, X-Grid, rF-Unroll|]
+        // Idx:     0     |    1(-5)     2(-4)    3(-3)   4(-2)    5(-1)  |
+        //                -------------------------------------------------
+        //                Reduction Dimensions
+        red_tv->reorder({{-1, -4}, {-4, -1}});
         auto red_tv_rf = red_tv->rFactor(
             {-5, -1}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-        red_tv_rf->computeAt(red_tv, 1);
-        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [Outputs, |X-Warp, X-Block, X-Grid, rF-Leftover, rF-Unroll|]
+        // Idx:     0     | 1(-5)    2(-4)   3(-3)     4(-2)       5(-1)  |
+        //                -------------------------------------------------
+        //                Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Outputs, X-Warp, X-Block, X-Grid]
+        // Idx:     0      1(-3)   2(-2)   3(-1)
-        // Cross-block reduction binding
-        red_tv_rf->axis(-4)->parallelize(ParallelType::BIDy);
-        red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
-        red_tv_rf->axis(-2)->parallelize(ParallelType::TIDx);
         red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
-        red_tv->axis(-3)->parallelize(ParallelType::BIDy);
+        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        red_tv->axis(-3)->parallelize(ParallelType::TIDx);
         red_tv->axis(-2)->parallelize(ParallelType::TIDy);
-        red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+        red_tv->axis(-1)->parallelize(ParallelType::BIDy);
+        // Bind Inputs to Reduction
+        // The computeAt is not to the innermost dimension of the rFactored
+        // tensor in order to force the creation of separate loop nests to cause
+        // Inputs to be separately read in their own loop.
+        //                                              computeAt(-2)------|
+        //                                                                 V
+        //      [Outputs, X-Warp, X-Block, X-Grid, rF-Leftover,| rF-Unroll]
+        // Idx:     0      1(-5)   2(-4)    3(-3)     4(-2)        5(-1)
         Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
         if (!fusion->hasInput(input)) {
           input->as<TensorView>()->computeAt(red_tv_rf, -2);
           input->as<TensorView>()->axis(-1)->parallelize(ParallelType::Unroll);
         }
       } else {
-        red_tv->split(1, 4);
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, rF-Unroll, X-Block, X-Warp|]
+        // Idx:     0     |    1(-4)      2(-3)     3(-2)    4(-1) |
+        //                -----------------------------------------
+        //                Reduction Dimensions
         red_tv->split(1, rparams.block_dim_x_);
-        // Split up rFactor to reduce across warps
         red_tv->split(1, rparams.block_dim_y_);
+        red_tv->split(1, kLoopUnrollSplit);
+
+        // Reordering the Unroll dimension eases applying computeAt()
+        // for preceding operations and the rFactored Tensor.
+        //                        |--- Reordered ----|
+        //                        V                  V
+        //      [outputs, |rF-Leftover, X-Warp, X-Block, rF-Unroll|]
+        // Idx:     0     |    1(-4)     2(-3)    3(-2)    4(-1)  |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv->reorder({{-1, -3}, {-3, -1}});
         auto red_tv_rf = red_tv->rFactor({-4, -1});
-        red_tv_rf->computeAt(red_tv, 1);
-        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [Outputs, |X-Warp, X-Block, rF-Leftover, rF-Unroll|]
+        // Idx:     0     | 1(-4)    2(-3)     3(-2)        4(-1) |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Outputs, X-Warp, X-Block]
+        // Idx:     0      1(-2)   2(-1)
-        red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
-        red_tv_rf->axis(-2)->parallelize(ParallelType::TIDx);
         red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
-        red_tv->axis(-2)->parallelize(ParallelType::TIDy);
-        red_tv->axis(-1)->parallelize(ParallelType::TIDx);
+        red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        red_tv->axis(-2)->parallelize(ParallelType::TIDx);
+        red_tv->axis(-1)->parallelize(ParallelType::TIDy);
+        // Bind Inputs to Reduction
+        // The computeAt is not to the innermost dimension of the rFactored
+        // tensor in order to force the creation of separate loop nests to cause
+        // Inputs to be separately read in their own loop.
+        //                                     computeAt(-2)------|
+        //                                                        V
+        //      [Outputs, X-Warp, X-Block, rF-Leftover,| rF-Unroll]
+        // Idx:     0      1(-4)    2(-3)     3(-2)        4(-1)
         Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
         if (!fusion->hasInput(input)) {
           input->as<TensorView>()->computeAt(red_tv_rf, -2);
@@ -498,47 +596,125 @@ c10::optional<ReductionParams> scheduleReduction(
       }
     }
   } else {
-    if (rparams.cross_warp_) {
-      if (rparams.cross_block_) {
-        red_tv->split(1, 4);
+    if (rparams.cross_block_) {
+      if (rparams.cross_grid_) {
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, rF-Unroll, X-Block, X-Grid|]
+        // Idx:     0     |    1(-4)      2(-3)     3(-2)    4(-1) |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv->split(1, rparams.block_dim_y_);
         red_tv->split(1, rparams.grid_dim_y_);
         red_tv->split(1, rparams.block_dim_y_);
+        red_tv->split(1, kLoopUnrollSplit);
+
+        // Reordering the Unroll dimension eases applying computeAt()
+        // for preceding operations and the rFactored Tensor.
+        //                        |--- Reordered ----|
+        //                        V                  V
+        //      [outputs, |rF-Leftover, X-Grid, X-Block, rF-Unroll|]
+        // Idx:     0     |    1(-4)     2(-3)    3(-2)    4(-1)  |
+        //                -----------------------------------------
+        //                Reduction Dimensions
+        red_tv->reorder({{-1, -3}, {-3, -1}});
+
+        // Output Splits
+        //      [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+        // Idx:  |      0             1     |   2(-4) -- 5(-1)
+        //       ------------------------------
+        //       Output Dimensions
         red_tv->split(0, rparams.block_dim_x_);
+
         auto red_tv_rf = red_tv->rFactor({-4, -1});
-        // Bindings
-        red_tv_rf->axis(1)->parallelize(ParallelType::TIDx);
-        red_tv_rf->axis(0)->parallelize(ParallelType::BIDx);
-        red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
-        red_tv_rf->axis(-2)->parallelize(ParallelType::BIDy);
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [<Out Dims>, |X-Grid, X-Block, rF-Leftover, rF-Unroll|]
+        // Idx:    0 --  1   | 2(-4)    3(-3)     4(-2)        5(-1) |
+        //                   -----------------------------------------
+        //                   Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Out-Leftover, Out-PerBlock, X-Grid, X-Block]
+        // Idx:       0             1         2(-2)   3(-1)
+        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
-        red_tv->axis(1)->parallelize(ParallelType::TIDx);
         red_tv->axis(0)->parallelize(ParallelType::BIDx);
-        red_tv->axis(-1)->parallelize(ParallelType::BIDy);
-        red_tv->axis(-2)->parallelize(ParallelType::TIDy);
+        red_tv->axis(1)->parallelize(ParallelType::TIDx);
+        red_tv->axis(-2)->parallelize(ParallelType::BIDy);
+        red_tv->axis(-1)->parallelize(ParallelType::TIDy);
+        // Bind Inputs to Reduction
+        // The computeAt is not to the innermost dimension of the rFactored
+        // tensor in order to force the creation of separate loop nests to cause
+        // Inputs to be separately read in their own loop.
+        //                                        computeAt(-2)------|
+        //                                                           V
+        //      [<Out Dims>, X-Grid, X-Block, rF-Leftover,| rF-Unroll]
+        // Idx:    0 --  1    2(-4)    3(-3)     4(-2)         5(-1)
         Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
         if (!fusion->hasInput(input)) {
           input->as<TensorView>()->computeAt(red_tv_rf, -2);
           input->as<TensorView>()->axis(-1)->parallelize(ParallelType::Unroll);
         }
       } else {
-        red_tv->split(1, 4);
+        // Reduction Splits
+        //      [outputs, |rF-Leftover, rF-Unroll, X-Block|]
+        // Idx:     0     |    1(-3)      2(-2)     3(-1) |
+        //                ---------------------------------
+        //                Reduction Dimensions
         red_tv->split(1, rparams.block_dim_y_);
+        red_tv->split(1, kLoopUnrollSplit);
+
+        // Reordering the Unroll dimension eases applying computeAt()
+        // for preceding operations and the rFactored Tensor.
+        //                        |- Reordered -|
+        //                        V             V
+        //      [outputs, |rF-Leftover, X-Block, rF-Unroll|]
+        // Idx:     0     |    1(-3)      2(-2)    3(-1)  |
+        //                ---------------------------------
+        //                Reduction Dimensions
+        red_tv->reorder({{-1, -2}, {-2, -1}});
+
+        // Output Splits
+        //      [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
+        // Idx:  |      0             1     |   2(-3) -- 4(-1)
+        //       ------------------------------
+        //       Output Dimensions
         red_tv->split(0, rparams.block_dim_x_);
+
         auto red_tv_rf = red_tv->rFactor({-3, -1});
-        // Bindings
-        red_tv_rf->axis(1)->parallelize(ParallelType::TIDx);
-        red_tv_rf->axis(0)->parallelize(ParallelType::BIDx);
-        red_tv_rf->axis(-2)->parallelize(ParallelType::TIDy);
+        // WARNING: computeAt will coalesce the rFactored dimensions
+        // rFactored Reduction Tensor after computeAt():
+        //      [<Out Dims>, |X-Block, rF-Leftover, rF-Unroll|]
+        // Idx:    0 --  1   |  2(-3)     3(-2)       4(-1)  |
+        //                   ---------------------------------
+        //                   Reduction Dimensions
+        red_tv_rf->computeAt(red_tv, -1);
+
+        // After the Reduction Tensor has rFactoring applied
+        // Reduction Output Tensor:
+        //      [Out-Leftover, Out-PerBlock, X-Block]
+        // Idx:       0             1         2(-1)
+        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
-        red_tv->axis(1)->parallelize(ParallelType::TIDx);
         red_tv->axis(0)->parallelize(ParallelType::BIDx);
+        red_tv->axis(1)->parallelize(ParallelType::TIDx);
         red_tv->axis(-1)->parallelize(ParallelType::TIDy);
+        // Bind Inputs to Reduction
+        // The computeAt is not to the innermost dimension of the rFactored
+        // tensor in order to force the creation of separate loop nests to cause
+        // Inputs to be separately read in their own loop.
+        //                                       computeAt(-2)------|
+        //                                                          V
+        //      [<Out Dims>, X-Block, rF-Leftover,| rF-Unroll]
+        // Idx:    0 --  1     2(-3)     3(-2)        4(-1)
         Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
         if (!fusion->hasInput(input)) {
           input->as<TensorView>()->computeAt(red_tv_rf, -2);
diff --git a/torch/csrc/jit/codegen/cuda/scheduler.h b/torch/csrc/jit/codegen/cuda/scheduler.h
index 547e6bc79af4..9c67de52d246 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler.h
@@ -24,8 +24,8 @@ struct ReductionParams {

   // Reduction Attributes
   bool fastest_dim_ = true;
-  bool cross_warp_ = false;
   bool cross_block_ = false;
+  bool cross_grid_ = false;
   bool mul_reds_per_blk_ = false;
 };
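For readers of the patch, the rename can be read as: cross_block_ now means one reduction is computed cooperatively by the threads of a block (what the heuristic previously called cross_warp_), and cross_grid_ means the reduction is additionally spread across blockIdx.y (previously cross_block_). The sketch below restates the cross-grid decision from the reductionHeuristic hunks as a small self-contained function. The constant values, the std::min combination of the two block counts (that line falls between hunks and is not shown in the diff), and the function packaging are assumptions made for illustration, not code from the PR.

#include <algorithm>

namespace {

constexpr int kMaxValuesPerThread = 256; // assumed placeholder, not the PR's value
constexpr int kMinValuesPerThread = 16;  // assumed placeholder, not the PR's value

int ceilDiv(int a, int b) {
  return (a + b - 1) / b;
}

// Simplified stand-in for the real struct; only the fields used here.
struct ReductionParams {
  bool cross_block_ = false; // reduce across the threads of one block (was cross_warp_)
  bool cross_grid_ = false;  // additionally reduce across blockIdx.y (was cross_block_)
  int grid_dim_x_ = 1;
  int grid_dim_y_ = 1;
};

// Mirrors the "Cross-block reductions (if necessary)" hunk: only go cross-grid
// when a single block still leaves too many elements per thread, and bound the
// extra blocks both by the target grid size and by keeping at least
// kMinValuesPerThread elements of work per thread.
void maybeGoCrossGrid(
    ReductionParams& rparams,
    int red_elems_per_thread,
    int target_grid_size) {
  if (rparams.cross_block_ && red_elems_per_thread >= kMaxValuesPerThread &&
      rparams.grid_dim_x_ <= target_grid_size) {
    int blks_per_out_1 = ceilDiv(target_grid_size, rparams.grid_dim_x_);
    int blks_per_out_2 = ceilDiv(red_elems_per_thread, kMinValuesPerThread);
    // Combining the two bounds with std::min is assumed; that line sits
    // between the hunks shown above.
    int blks_per_output = std::min(blks_per_out_1, blks_per_out_2);
    rparams.grid_dim_y_ = std::max(1, blks_per_output);
    if (blks_per_output > 1) {
      rparams.cross_grid_ = true; // previously rparams.cross_block_
    }
  }
}

} // namespace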
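Every branch of scheduleReduction now follows the same recipe: split the reduction axis (finishing with kLoopUnrollSplit), reorder the unroll factor to the innermost position, rFactor the leftover and unroll axes, computeAt the rFactor tensor at -1 of the reduction output, bind the parallel axes, and finally computeAt the producer at -2 of the rFactor tensor so its read happens in a separate, unrolled loop nest. The sketch below condenses the first branch (fastest dim, multiple reductions per block) into one function. It is written as if it sat inside scheduler.cpp, so the nvFuser types (Fusion, TensorView, ParallelType, ReductionParams) come from that file; the function name and the in_tv parameter, which stands in for fusion->origin(red_tv_rf)->as<ReductionOp>()->in(), are hypothetical.

// A hedged, condensed sketch of the fastest-dim, multiple-reductions-per-block
// schedule above; not the PR's code verbatim.
void scheduleFastestDimMultiReduceSketch(
    Fusion* fusion,
    TensorView* red_tv,
    TensorView* in_tv,
    const ReductionParams& rparams) {
  constexpr int kLoopUnrollSplit = 4;

  // [outputs, R] -> [outputs, rF-Leftover, rF-Unroll, X-Warp]
  red_tv->split(1, rparams.block_dim_x_);
  red_tv->split(1, kLoopUnrollSplit);

  // Swap the unroll factor to the innermost position so later computeAt()
  // calls line up with it.
  red_tv->reorder({{-1, -2}, {-2, -1}});

  // [outputs] -> [Out-Leftover, Out-PerBlock]
  red_tv->split(0, rparams.block_dim_y_);

  // First-stage reduction over the leftover and unroll axes, fused into the
  // final reduction loop nest.
  auto red_tv_rf = red_tv->rFactor({-3, -1});
  red_tv_rf->computeAt(red_tv, -1);

  red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);
  red_tv->axis(0)->parallelize(ParallelType::BIDx);
  red_tv->axis(1)->parallelize(ParallelType::TIDy);
  red_tv->axis(-1)->parallelize(ParallelType::TIDx);

  // Stop one axis short of the innermost so the producer is read in its own
  // unrolled loop nest instead of inside the reduction loop.
  if (!fusion->hasInput(in_tv)) {
    in_tv->computeAt(red_tv_rf, -2);
    in_tv->axis(-1)->parallelize(ParallelType::Unroll);
  }
}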