86 changes: 39 additions & 47 deletions torch/csrc/jit/codegen/cuda/scheduler.cpp
@@ -248,22 +248,6 @@ constexpr int lastPow2(int n) {
return std::max(1, n - (n >> 1));
}

// Parameters the Reduction Heuristic generates to describe
// the optimal schedule
struct ReductionParams {
// Reduction Blocking
int grid_dim_x_ = 1;
int grid_dim_y_ = 1;
int block_dim_x_ = 1;
int block_dim_y_ = 1;

// Reduction Attributes
bool fastest_dim_ = true;
bool cross_warp_ = false;
bool cross_block_ = false;
bool mul_reds_per_blk_ = false;
};

ReductionParams reductionHeuristic(
int outer_dim,
int inner_dim,
@@ -385,7 +369,9 @@ ReductionParams reductionHeuristic(
} // anonymous namespace

// fusion is the input IR that will be modified by this function
bool scheduleReduction(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
c10::optional<ReductionParams> scheduleReduction(
Fusion* fusion,
const at::ArrayRef<c10::IValue>& inputs) {
FusionGuard fg(fusion);

// TODO: I am making a larger initial assumption that reductions are
@@ -395,14 +381,16 @@ bool scheduleReduction(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
// TODO: This is making an assumption there is only one reduction
// in a kernel. This will not be true in the long run.
TensorView* red_tv = nullptr;
DataType red_dtype = DataType::Null;
for (auto& expr : fusion->exprs(/*from_outputs_only*/ true)) {
if (expr->type() == ExprType::ReductionOp) {
red_tv = static_cast<TensorView*>(expr->output(0));
red_dtype = expr->input(0)->getDataType().value();
break;
}
}
if (red_tv == nullptr) { // No reduction found
return false;
return c10::nullopt;
}

EvaluationContext eval_context(fusion);
Expand Down Expand Up @@ -475,66 +463,70 @@ bool scheduleReduction(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
if (rparams.cross_block_) {
red_tv->split(-1, rparams.block_dim_x_);
// Split up rFactor to reduce across warps
red_tv->split(-2, rparams.block_dim_y_);
red_tv->split(-3, rparams.grid_dim_y_);
red_tv->split(0, rparams.grid_dim_y_);
red_tv->split(0, rparams.block_dim_y_);

auto red_tv_rf = red_tv->rFactor({-4});
red_tv_rf->computeAt(red_tv, 1);

red_tv->axis(0)->parallelize(ParallelType::BIDx);

// Cross-block reduction binding
red_tv_rf->axis(-3)->parallelize(ParallelType::BIDy);
red_tv->axis(-3)->parallelize(ParallelType::BIDy);

red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
red_tv_rf->axis(-2)->parallelize(ParallelType::BIDy);
red_tv_rf->axis(-1)->parallelize(ParallelType::TIDx);
red_tv->axis(-1)->parallelize(ParallelType::TIDx);

red_tv_rf->axis(-2)->parallelize(ParallelType::TIDy);
red_tv->axis(-2)->parallelize(ParallelType::TIDy);
red_tv->axis(-3)->parallelize(ParallelType::TIDy);
red_tv->axis(-2)->parallelize(ParallelType::BIDy);
red_tv->axis(-1)->parallelize(ParallelType::TIDx);

} else {
red_tv->split(-1, rparams.block_dim_x_);
// Split up rFactor to reduce across warps
red_tv->split(-2, rparams.block_dim_y_);
red_tv->split(0, rparams.block_dim_y_);

auto red_tv_rf = red_tv->rFactor({-3});
red_tv_rf->computeAt(red_tv, 1);

red_tv->axis(0)->parallelize(ParallelType::BIDx);

red_tv_rf->axis(-2)->parallelize(ParallelType::TIDy);
red_tv_rf->axis(-1)->parallelize(ParallelType::TIDx);
red_tv->axis(-1)->parallelize(ParallelType::TIDx);

red_tv_rf->axis(-2)->parallelize(ParallelType::TIDy);
red_tv->axis(-2)->parallelize(ParallelType::TIDy);
red_tv->axis(-1)->parallelize(ParallelType::TIDx);
}
}
} else {
// TODO: This block param needs to be replaced by a
// proper attribute when I determine the proper one
if (rparams.block_dim_y_ > 1) {
red_tv->split(-1, rparams.block_dim_x_);
if (rparams.cross_warp_) {
if (rparams.cross_block_) {
red_tv->split(-1, rparams.block_dim_x_);
red_tv->split(0, rparams.grid_dim_y_);
}
red_tv->split(0, rparams.block_dim_y_);
auto red_tv_rf = red_tv->rFactor({0});
red_tv_rf->axis(-1)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(-2)->parallelize(ParallelType::BIDx);
if (rparams.cross_block_) {
red_tv->split(0, rparams.block_dim_y_);
auto red_tv_rf = red_tv->rFactor({0});

// Bindings
red_tv_rf->axis(-1)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(-2)->parallelize(ParallelType::BIDx);
red_tv_rf->axis(-3)->parallelize(ParallelType::BIDy);
red_tv_rf->axis(-4)->parallelize(ParallelType::TIDy);
} else {
red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);
}
red_tv->axis(-1)->parallelize(ParallelType::TIDx);
red_tv->axis(-2)->parallelize(ParallelType::BIDx);
red_tv->axis(-3)->parallelize(ParallelType::TIDy);
if (rparams.cross_block_) {

red_tv->axis(-1)->parallelize(ParallelType::TIDx);
red_tv->axis(-2)->parallelize(ParallelType::BIDx);
red_tv->axis(-3)->parallelize(ParallelType::BIDy);
red_tv->axis(-4)->parallelize(ParallelType::TIDy);
} else {
red_tv->split(-1, rparams.block_dim_x_);
red_tv->split(0, rparams.block_dim_y_);
auto red_tv_rf = red_tv->rFactor({0});

// Bindings
red_tv_rf->axis(-1)->parallelize(ParallelType::TIDx);
red_tv_rf->axis(-2)->parallelize(ParallelType::BIDx);
red_tv_rf->axis(-3)->parallelize(ParallelType::TIDy);

red_tv->axis(-1)->parallelize(ParallelType::TIDx);
red_tv->axis(-2)->parallelize(ParallelType::BIDx);
red_tv->axis(-3)->parallelize(ParallelType::TIDy);
}
} else {
@@ -558,7 +550,7 @@ bool scheduleReduction(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));

return true;
return rparams;
}

} // namespace cuda
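For readers of the diff: a minimal sketch of how a call site might consume the new c10::optional<ReductionParams> return value. The surrounding fusion and inputs variables are assumed to be in scope, and the fallback to scheduleFusion is illustrative rather than the PR's actual call-site behavior.

// Illustrative only: scheduleReduction now reports the chosen heuristic
// (or c10::nullopt when the fusion contains no ReductionOp) instead of a
// bare bool, so callers can inspect or cache the parameters.
c10::optional<ReductionParams> rparams = scheduleReduction(fusion, inputs);
if (!rparams.has_value()) {
  // No reduction found; fall back to the pointwise scheduler (assumed).
  scheduleFusion(fusion, inputs);
} else {
  // The returned params could key a kernel cache or drive logging.
  const ReductionParams& rp = *rparams;
  (void)rp.cross_block_;
}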
20 changes: 18 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler.h
@@ -14,12 +14,28 @@ TORCH_CUDA_API bool scheduleFusion(
Fusion* fusion,
const at::ArrayRef<c10::IValue> inputs);

// Parameters the Reduction Heuristic generates to describe
// the optimal schedule
struct ReductionParams {
// Reduction Blocking
int grid_dim_x_ = 1;
int grid_dim_y_ = 1;
int block_dim_x_ = 1;
int block_dim_y_ = 1;

// Reduction Attributes
bool fastest_dim_ = true;
bool cross_warp_ = false;
bool cross_block_ = false;
bool mul_reds_per_blk_ = false;
};

// TODO: This function is currently a redundant API while I populate a more
// substantial reduction heuristic
// fusion is the input IR that will be modified by this function
TORCH_CUDA_API bool scheduleReduction(
TORCH_CUDA_API c10::optional<ReductionParams> scheduleReduction(
Fusion* fusion,
const at::ArrayRef<c10::IValue> inputs);
const at::ArrayRef<c10::IValue>& inputs);

} // namespace cuda
} // namespace fuser
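To make the blocking fields concrete: a hypothetical pair of helpers (not part of this PR) showing how the four dimension members could map onto CUDA launch dimensions, mirroring the ParallelType::TIDx/TIDy/BIDx/BIDy bindings that scheduleReduction applies in the diff above.

#include <cuda_runtime.h> // for dim3

// Hypothetical helpers, not in the PR: block_dim_* describe threadIdx
// extents and grid_dim_* describe blockIdx extents for the generated
// reduction kernel.
inline dim3 blockDimsFromParams(const ReductionParams& rp) {
  return dim3(rp.block_dim_x_, rp.block_dim_y_, 1);
}

inline dim3 gridDimsFromParams(const ReductionParams& rp) {
  return dim3(rp.grid_dim_x_, rp.grid_dim_y_, 1);
}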