Merged
280 changes: 228 additions & 52 deletions torch/csrc/jit/codegen/cuda/scheduler.cpp
@@ -295,11 +295,11 @@ ReductionParams reductionHeuristic(
red_elems_per_thread >= kMaxValuesPerThread || !rparams.fastest_dim_) {
inputs_consumed_per_block_iter *= rparams.block_dim_y_;
red_elems_per_thread = ceilDiv(red_elems_per_thread, rparams.block_dim_y_);
rparams.cross_warp_ = true;
rparams.cross_block_ = true;
rparams.mul_reds_per_blk_ = false;
// Do multiple reductions per block
} else {
rparams.cross_warp_ = false;
rparams.cross_block_ = false;
rparams.mul_reds_per_blk_ = true;
outputs_produced_per_block_iter *= rparams.block_dim_y_;
}
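For orientation, the heuristic above toggles a small set of ReductionParams fields that the scheduling code below consumes. A minimal sketch of the struct as implied by the fields referenced in this file (the authoritative definition lives elsewhere in the codebase and may carry more members):

// Sketch only -- inferred from the fields used in this diff,
// not the authoritative definition.
struct ReductionParams {
  bool fastest_dim_ = true;       // reduction runs along the innermost (fastest) dimension
  bool cross_block_ = false;      // reduce across threads within a block
  bool cross_grid_ = false;       // reduce across blocks (grid reduction)
  bool mul_reds_per_blk_ = false; // one block handles several reduction outputs
  int block_dim_x_ = 1;           // bound to TIDx
  int block_dim_y_ = 1;           // bound to TIDy
  int grid_dim_x_ = 1;            // bound to BIDx
  int grid_dim_y_ = 1;            // bound to BIDy
};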
@@ -320,7 +320,7 @@ ReductionParams reductionHeuristic(
rparams.grid_dim_x_ = ceilDiv(red_outputs, outputs_produced_per_block_iter);

// Cross-block reductions (if necessary)
if (rparams.cross_warp_ && red_elems_per_thread >= kMaxValuesPerThread &&
if (rparams.cross_block_ && red_elems_per_thread >= kMaxValuesPerThread &&
rparams.grid_dim_x_ <= target_grid_size) {
int blks_per_out_1 = ceilDiv(target_grid_size, rparams.grid_dim_x_);
int blks_per_out_2 = ceilDiv(red_elems_per_thread, kMinValuesPerThread);
@@ -331,7 +331,7 @@ ReductionParams reductionHeuristic(
rparams.grid_dim_y_ = std::max(1, blks_per_output);
// If a cross-block reduction was generated
if (blks_per_output > 1) {
rparams.cross_block_ = true;
rparams.cross_grid_ = true;
}
}
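ceilDiv here is the usual integer ceiling division; a one-line sketch of the assumed helper, which makes the sizing above read naturally: grid_dim_y_ is the number of blocks cooperating on one reduction output, capped so that each thread still keeps at least kMinValuesPerThread elements.

// Assumed helper (sketch): integer ceiling division used throughout the heuristic.
constexpr int ceilDiv(int a, int b) {
  return (a + b - 1) / b;
}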

@@ -343,8 +343,8 @@ ReductionParams reductionHeuristic(
<< " Red On Fastest Dim? " << red_on_fastest_dim << std::endl
<< "Reduction Characteristics:" << std::endl
<< "\tMultiple Reds Per Block? " << rparams.mul_reds_per_blk_
<< " Cross Warp? " << rparams.cross_warp_ << " Cross Block? "
<< rparams.cross_block_ << std::endl
<< " Cross Warp? " << rparams.cross_block_ << " Cross Block? "
<< rparams.cross_grid_ << std::endl
<< "Recommended Blocking:" << std::endl
<< "\tGridX: " << rparams.grid_dim_x_
<< " GridY: " << rparams.grid_dim_y_
@@ -415,81 +415,179 @@ c10::optional<ReductionParams> scheduleReduction(
ReductionParams rparams = reductionHeuristic(
red_elems.value(), red_outputs.value(), red_on_fastest_dim);

// Heuristic Definition
constexpr int kLoopUnrollSplit = 4;

// Scheduling the Reduction
if (rparams.fastest_dim_) {
// Do multiple reductions per block
if (rparams.mul_reds_per_blk_) {
// Unroll a certain number of rFactored elements
red_tv->split(1, 4);
// Reduction Splits
// [outputs, |rF-Leftover, rF-Unroll, X-Warp|]
// Idx: 0 | 1(-3) 2(-2) 3(-1) |
// --------------------------------
// Reduction Dimensions
red_tv->split(1, rparams.block_dim_x_);
// Unroll a certain number of rFactored elements
// Split Grid dimension to get multiple reds per block
red_tv->split(1, kLoopUnrollSplit);

// Reordering the Unroll dimension eases applying computeAt()
// for preceding operations and the rFactored Tensor.
// |- Reordered -|
// V V
// [outputs, |rF-Leftover, X-Warp, rF-Unroll|]
// Idx: 0 | 1(-3) 2(-2) 3(-1) |
// --------------------------------
// Reduction Dimensions
red_tv->reorder({{-1, -2}, {-2, -1}});
Collaborator: a comment would help here to explain the logic behind the reorder

Owner: Can you increase the computeAt point because of the reorder?
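To make the reorder easier to follow (and in the spirit of the comment above): reorder() takes {old_position, new_position} pairs, so {{-1, -2}, {-2, -1}} simply swaps the last two axes, pushing the unroll axis innermost. A small annotated restatement of the call, assuming the TensorView API used throughout this file:

// Before: [outputs, rF-Leftover, rF-Unroll, X-Warp]
//              0        1(-3)       2(-2)    3(-1)
red_tv->reorder({{-1, -2},    // X-Warp   moves from -1 to -2
                 {-2, -1}});  // rF-Unroll moves from -2 to -1
// After:  [outputs, rF-Leftover, X-Warp, rF-Unroll]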


// Output Splits
// [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
// Idx: | 0 1 | 2(-3) -- 4(-1)
// ----------------------------
// Output Dimensions
red_tv->split(0, rparams.block_dim_y_);

auto red_tv_rf = red_tv->rFactor({-3, -1});
red_tv_rf->computeAt(red_tv, 1);

// WARNING: computeAt will coalesce the rFactored dimensions
// rFactored Reduction Tensor after computeAt():
// [<output dims>, |X-Warp, rF-Leftover, rF-Unroll|]
// Idx: 0 -- 1 | 2(-3) 3(-2) 4(-1) |
// ---------------------------------
// Reduction Dimensions
red_tv_rf->computeAt(red_tv, -1);

// After the Reduction Tensor has rFactoring applied
// Reduction Output Tensor:
// [Out-Leftover, Out-PerBlock, X-Warp]
// Idx: 0 1 2(-1)
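rFactor splits the single reduction into a two-stage reduction. As used here, and assuming the usual TensorView::rFactor semantics (the listed axes remain reduction axes on the newly created partial-result tensor, the rest are reduced by the original tensor), the staging is:

// Conceptual effect of red_tv->rFactor({-3, -1}) in this branch (sketch only):
//   red_tv_rf : partial sums, reducing rF-Leftover (-3) and rF-Unroll (-1) per thread
//   red_tv    : final result, reducing the remaining X-Warp axis across TIDx
auto red_tv_rf = red_tv->rFactor({-3, -1});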

red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

red_tv->axis(0)->parallelize(ParallelType::BIDx);
Owner: Did you reparallelize this? Looks like it's duplicated

red_tv->axis(1)->parallelize(ParallelType::TIDy);
red_tv->axis(-1)->parallelize(ParallelType::TIDx);

red_tv_rf->axis(1)->parallelize(ParallelType::TIDy);
red_tv_rf->axis(-2)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

// Bind Inputs to Reduction
// The computeAt is not to the inner most dimension of the rFactored
// tensor in order to force the creation of separate loop nests to cause
// Inputs to be separately read in their own loop.
// computeAt(-2)------|
// V
// [<output dims>, X-Warp, rF-Leftover,| rF-Unroll]
// Idx: 0 -- 1 2(-3) 3(-2) 4(-1)
Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
if (!fusion->hasInput(input)) {
input->as<TensorView>()->computeAt(red_tv_rf, -2);
input->as<TensorView>()->axis(-1)->parallelize(ParallelType::Unroll);
}
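The intent of stopping computeAt at -2 (described in the comment above) is a loop nest in which the unrolled input reads get their own innermost loop, separate from the unrolled accumulation. Roughly, and purely as an illustration rather than actual generated code:

// Illustrative shape of the loop nest this computeAt(-2) aims for (sketch):
// for each output            // BIDx / TIDy per the bindings above
//   for rF-Leftover:
//     for rF-Unroll:         // Unroll -- inputs read in their own loop
//       tmp[u] = load input
//     for rF-Unroll:         // Unroll -- accumulation in a separate loop
//       red_tv_rf += tmp[u]
//   reduce red_tv_rf across X-Warp (TIDx) into red_tv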
// Do a cross-warp reduction per block
} else {
if (rparams.cross_block_) {
red_tv->split(1, 4);
if (rparams.cross_grid_) {
// Reduction Splits
// [outputs, |rF-Leftover, rF-Unroll, X-Block, X-Grid, X-Warp|]
// Idx: 0 | 1(-5) 2(-4) 3(-3) 4(-2) 5(-1) |
// -------------------------------------------------
// Reduction Dimensions
red_tv->split(1, rparams.block_dim_x_);
// Split up rFactor to reduce across warps
red_tv->split(1, rparams.grid_dim_y_);
red_tv->split(1, rparams.block_dim_y_);
red_tv->split(1, kLoopUnrollSplit);

// Reordering the Unroll dimension eases applying computeAt()
// for preceding operations and the rFactored Tensor.
// |------ Reordered --------|
// V V
// [outputs, |rF-Leftover, X-Warp, X-Block, X-Grid, rF-Unroll|]
// Idx: 0 | 1(-5) 2(-4) 3(-3) 4(-2) 5(-1) |
// -------------------------------------------------
// Reduction Dimensions
red_tv->reorder({{-1, -4}, {-4, -1}});
Collaborator: same (as the reorder comment above)

auto red_tv_rf = red_tv->rFactor(
{-5, -1}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
red_tv_rf->computeAt(red_tv, 1);

red_tv->axis(0)->parallelize(ParallelType::BIDx);
// WARNING: computeAt will coalesce the rFactored dimensions
// rFactored Reduction Tensor after computeAt():
// [Outputs, |X-Warp, X-Block, X-Grid, rF-Leftover, rF-Unroll|]
// Idx: 0 | 1(-5) 2(-4) 3(-3) 4(-2) 5(-1) |
// -------------------------------------------------
// Reduction Dimensions
red_tv_rf->computeAt(red_tv, -1);

// After the Reduction Tensor has rFactoring applied
// Reduction Output Tensor:
// [Outputs, X-Warp, X-Block, X-Grid]
// Idx: 0 1(-3) 2(-2) 3(-1)

// Cross-block reduction binding
red_tv_rf->axis(-4)->parallelize(ParallelType::BIDy);
red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
red_tv_rf->axis(-2)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

red_tv->axis(-3)->parallelize(ParallelType::BIDy);
red_tv->axis(0)->parallelize(ParallelType::BIDx);
red_tv->axis(-3)->parallelize(ParallelType::TIDx);
red_tv->axis(-2)->parallelize(ParallelType::TIDy);
red_tv->axis(-1)->parallelize(ParallelType::TIDx);
red_tv->axis(-1)->parallelize(ParallelType::BIDy);

// Bind Inputs to Reduction
// The computeAt is not to the inner most dimension of the rFactored
// tensor in order to force the creation of separate loop nests to cause
// Inputs to be separately read in their own loop.
// computeAt(-2)------|
// V
// [Outputs, X-Warp, X-Block, X-Grid, rF-Leftover,| rF-Unroll]
// Idx: 0 1(-5) 2(-4) 3(-3) 4(-2) 5(-1)
Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
if (!fusion->hasInput(input)) {
input->as<TensorView>()->computeAt(red_tv_rf, -2);
input->as<TensorView>()->axis(-1)->parallelize(ParallelType::Unroll);
}
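Putting the bindings of this cross-grid branch together, the implied launch geometry follows directly from the rparams fields. A sketch for orientation only; the actual launch configuration is derived by the executor, and dim3 is just the CUDA type used to spell it out:

// Launch shape implied by the ParallelType bindings above (illustration only).
dim3 grid(rparams.grid_dim_x_,   // BIDx: independent reduction outputs
          rparams.grid_dim_y_);  // BIDy: blocks cooperating on one output (cross-grid)
dim3 block(rparams.block_dim_x_, // TIDx: X-Warp lanes
           rparams.block_dim_y_);// TIDy: X-Block rows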
} else {
red_tv->split(1, 4);
// Reduction Splits
// [outputs, |rF-Leftover, rF-Unroll, X-Block, X-Warp|]
// Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
// -----------------------------------------
// Reduction Dimensions
red_tv->split(1, rparams.block_dim_x_);
// Split up rFactor to reduce across warps
red_tv->split(1, rparams.block_dim_y_);
red_tv->split(1, kLoopUnrollSplit);

// Reordering the Unroll dimension eases applying computeAt()
// for preceding operations and the rFactored Tensor.
// |--- Reordered ----|
// V V
// [outputs, |rF-Leftover, X-Warp, X-Block, rF-Unroll|]
// Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
// -----------------------------------------
// Reduction Dimensions
red_tv->reorder({{-1, -3}, {-3, -1}});

auto red_tv_rf = red_tv->rFactor({-4, -1});
red_tv_rf->computeAt(red_tv, 1);

red_tv->axis(0)->parallelize(ParallelType::BIDx);
// WARNING: computeAt will coalesce the rFactored dimensions
// rFactored Reduction Tensor after computeAt():
// [Outputs, |X-Warp, X-Block, rF-Leftover, rF-Unroll|]
// Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
// -----------------------------------------
// Reduction Dimensions
red_tv_rf->computeAt(red_tv, -1);

// After the Reduction Tensor has rFactoring applied
// Reduction Output Tensor:
// [Outputs, X-Warp, X-Block]
// Idx: 0 1(-2) 2(-1)

red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
red_tv_rf->axis(-2)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

red_tv->axis(-2)->parallelize(ParallelType::TIDy);
red_tv->axis(-1)->parallelize(ParallelType::TIDx);
red_tv->axis(0)->parallelize(ParallelType::BIDx);
red_tv->axis(-2)->parallelize(ParallelType::TIDx);
red_tv->axis(-1)->parallelize(ParallelType::TIDy);

// Bind Inputs to Reduction
// The computeAt is not to the inner most dimension of the rFactored
// tensor in order to force the creation of separate loop nests to cause
// Inputs to be separately read in their own loop.
// computeAt(-2)------|
// V
// [Outputs, X-Warp, X-Block, rF-Leftover,| rF-Unroll]
// Idx: 0 1(-4) 2(-3) 3(-2) 4(-1)
Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
if (!fusion->hasInput(input)) {
input->as<TensorView>()->computeAt(red_tv_rf, -2);
@@ -498,47 +596,125 @@ c10::optional<ReductionParams> scheduleReduction(
}
}
} else {
if (rparams.cross_warp_) {
if (rparams.cross_block_) {
red_tv->split(1, 4);
if (rparams.cross_block_) {
if (rparams.cross_grid_) {
// Reduction Splits
// [outputs, |rF-Leftover, rF-Unroll, X-Block, X-Grid|]
// Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
// -----------------------------------------
// Reduction Dimensions
red_tv->split(1, rparams.block_dim_y_);
red_tv->split(1, rparams.grid_dim_y_);
red_tv->split(1, rparams.block_dim_y_);
red_tv->split(1, kLoopUnrollSplit);

// Reordering the Unroll dimension eases applying computeAt()
// for preceding operations and the rFactored Tensor.
// |--- Reordered ----|
// V V
// [outputs, |rF-Leftover, X-Grid, X-Block, rF-Unroll|]
// Idx: 0 | 1(-4) 2(-3) 3(-2) 4(-1) |
// -----------------------------------------
// Reduction Dimensions
red_tv->reorder({{-1, -3}, {-3, -1}});

// Output Splits
// [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
// Idx: | 0 1 | 2(-4) -- 5(-1)
// ----------------------------
// Output Dimensions
red_tv->split(0, rparams.block_dim_x_);

auto red_tv_rf = red_tv->rFactor({-4, -1});

// Bindings
red_tv_rf->axis(1)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(0)->parallelize(ParallelType::BIDx);
red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
red_tv_rf->axis(-2)->parallelize(ParallelType::BIDy);
// WARNING: computeAt will coalesce the rFactored dimensions
// rFactored Reduction Tensor after computeAt():
// [<output dims>, |X-Grid, X-Block, rF-Leftover, rF-Unroll|]
// Idx: 0 -- 1 | 2(-4) 3(-3) 4(-2) 5(-1) |
// -----------------------------------------
// Reduction Dimensions
red_tv_rf->computeAt(red_tv, -1);

// After the Reduction Tensor has rFactoring applied
// Reduction Output Tensor:
// [Out-Leftover, Out-PerBlock, X-Grid, X-Block]
// Idx: 0 1 2(-2) 3(-1)

red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

red_tv->axis(1)->parallelize(ParallelType::TIDx);
red_tv->axis(0)->parallelize(ParallelType::BIDx);
red_tv->axis(-1)->parallelize(ParallelType::BIDy);
red_tv->axis(-2)->parallelize(ParallelType::TIDy);
red_tv->axis(1)->parallelize(ParallelType::TIDx);
red_tv->axis(-2)->parallelize(ParallelType::BIDy);
red_tv->axis(-1)->parallelize(ParallelType::TIDy);

// Bind Inputs to Reduction
// The computeAt is not to the inner most dimension of the rFactored
// tensor in order to force the creation of separate loop nests to cause
// Inputs to be separately read in their own loop.
// computeAt(-2)------|
// V
// [<output dims>, X-Grid, X-Block, rF-Leftover,| rF-Unroll]
// Idx: 0 -- 1 2(-4) 3(-3) 4(-2) 5(-1)
Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
if (!fusion->hasInput(input)) {
input->as<TensorView>()->computeAt(red_tv_rf, -2);
input->as<TensorView>()->axis(-1)->parallelize(ParallelType::Unroll);
}
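A note on why this non-fastest-dim path binds TIDx to an output axis instead of a reduction axis (an interpretation of the schedule above, not a statement from the diff): the outputs now lie along the fastest-varying dimension, so giving consecutive TIDx values consecutive outputs keeps global memory reads coalesced, while TIDy and BIDy carry the reduction.

// Binding pattern of this branch, restated (illustration only):
//   output index    -> BIDx (Out-Leftover), TIDx (Out-PerBlock)
//   reduction index -> BIDy (X-Grid),       TIDy (X-Block)
// Adjacent threadIdx.x values touch adjacent outputs, so loads from the
// input tensor stay coalesced even though the reduction axis is strided.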
} else {
red_tv->split(1, 4);
// Reduction Splits
// [outputs, |rF-Leftover, rF-Unroll, X-Block|]
// Idx: 0 | 1(-3) 2(-2) 3(-1) |
// ---------------------------------
// Reduction Dimensions
red_tv->split(1, rparams.block_dim_y_);
red_tv->split(1, kLoopUnrollSplit);

// Reordering the Unroll dimension eases applying computeAt()
// for preceding operations and the rFactored Tensor.
// |- Reordered -|
// V V
// [outputs, |rF-Leftover, X-Block, rF-Unroll|]
// Idx: 0 | 1(-3) 2(-2) 3(-1) |
// ---------------------------------
// Reduction Dimensions
red_tv->reorder({{-1, -2}, {-2, -1}});

// Output Splits
// [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
// Idx: | 0 1 | 2(-3) -- 4(-1)
// ----------------------------
// Output Dimensions
red_tv->split(0, rparams.block_dim_x_);

auto red_tv_rf = red_tv->rFactor({-3, -1});

// Bindings
red_tv_rf->axis(1)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(0)->parallelize(ParallelType::BIDx);
red_tv_rf->axis(-2)->parallelize(ParallelType::TIDy);
// WARNING: computeAt will coalesce the rFactored dimensions
// rFactored Reduction Tensor after computeAt():
// [<output dims>, |X-Block, rF-Leftover, rF-Unroll|]
// Idx: 0 -- 1 | 2(-3) 3(-2) 4(-1) |
// ---------------------------------
// Reduction Dimensions
red_tv_rf->computeAt(red_tv, -1);

// After the Reduction Tensor has rFactoring applied
// Reduction Output Tensor:
// [Out-Leftover, Out-PerBlock, X-Block]
// Idx: 0 1 2(-1)

red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

red_tv->axis(1)->parallelize(ParallelType::TIDx);
red_tv->axis(0)->parallelize(ParallelType::BIDx);
red_tv->axis(1)->parallelize(ParallelType::TIDx);
red_tv->axis(-1)->parallelize(ParallelType::TIDy);

// Bind Inputs to Reduction
// The computeAt is not to the inner most dimension of the rFactored
// tensor in order to force the creation of separate loop nests to cause
// Inputs to be separately read in their own loop.
// computeAt(-2)------|
// V
// [<output dims>, X-Block, rF-Leftover,| rF-Unroll]
// Idx: 0 -- 1 2(-3) 3(-2) 4(-1)
Val* input = fusion->origin(red_tv_rf)->as<ReductionOp>()->in();
if (!fusion->hasInput(input)) {
input->as<TensorView>()->computeAt(red_tv_rf, -2);