From 7450ba9cc34d899c542a9a11c7a6573393235a0a Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 00:28:50 -0700 Subject: [PATCH 01/15] Add shallowCopy --- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 11 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 46 ++++ torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 212 +++++++++++++----- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 133 ++++++++++- torch/csrc/jit/codegen/cuda/kernel_ir.h | 32 +++ torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 3 + 6 files changed, 376 insertions(+), 61 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 7d5ebad25282..ad3e010f5362 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -426,6 +426,8 @@ class TORCH_CUDA_CU_API Expr : public Statement { Expr(const Expr* src, IrCloner* ir_cloner); + virtual Expr *shallowCopy() const = 0; + c10::optional getExprType() const override { return etype_; } @@ -467,15 +469,22 @@ class TORCH_CUDA_CU_API Expr : public Statement { kir::Predicate* predicate() const; // TODO: Protect based on being in kernel container - void setPredicate(kir::Predicate* predicate); + Expr* withPredicate(kir::Predicate* predicate); // TODO: Protect based on being in kernel container kir::Predicate* writePredicate() const; + // TODO: Protect based on being in kernel container + Expr* withWritePredicate(kir::Predicate* write_predicate); + + // TODO: Protect based on being in kernel container + void setPredicate(kir::Predicate* predicate); + // TODO: Protect based on being in kernel container void setWritePredicate(kir::Predicate* write_predicate); protected: + // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index aa8793366a32..658860d0ea17 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -36,6 +36,8 @@ class TORCH_CUDA_CU_API FullOp : public Expr { FullOp(const FullOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + bool sameAs(const Statement* other) const override; DataType dtype() const { @@ -64,6 +66,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { ARangeOp(const ARangeOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + bool sameAs(const Statement* other) const override; DataType dtype() const { @@ -127,6 +131,8 @@ class TORCH_CUDA_CU_API EyeOp : public Expr { EyeOp(const EyeOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + bool sameAs(const Statement* other) const override; DataType dtype() const { @@ -172,6 +178,8 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -201,6 +209,8 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr { BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -239,6 +249,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr { RNGOp(const RNGOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + RNGOpType getRNGOpType() const { return rng_op_type_; } @@ -298,6 +310,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return 
out_; } @@ -346,6 +360,8 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -394,6 +410,8 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. size_t numExprs() const { @@ -580,6 +598,8 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { WelfordOp(const WelfordOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return output().avg(); } @@ -675,6 +695,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. As horizontal grouping is not supported, //! this always returns 1. @@ -798,6 +820,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { MmaOp(const MmaOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -856,6 +880,8 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr { TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + TensorView* out() const { return out_; } @@ -886,6 +912,8 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr { ExpandOp(const ExpandOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + TensorView* out() const { return out_; } @@ -916,6 +944,8 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr { TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -959,6 +989,8 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -1008,6 +1040,8 @@ class TORCH_CUDA_CU_API GatherOp : public Expr { GatherOp(const GatherOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -1054,6 +1088,8 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr { ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -1087,6 +1123,8 @@ class TORCH_CUDA_CU_API ViewOp : public Expr { ViewOp(const ViewOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + TensorView* out() const { return out_; } @@ -1112,6 +1150,8 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr { LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -1691,6 +1731,8 @@ class TORCH_CUDA_CU_API Split : public Expr { Split(const Split* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + IterDomain* outer() const { return outer_; } @@ -1751,6 +1793,8 @@ class TORCH_CUDA_CU_API Merge : public Expr { Merge(const Merge* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + IterDomain* out() const { return out_; } @@ -1783,6 +1827,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner); + Expr *shallowCopy() const override; + IterDomain* outX() const { return out_x_; } diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp 
b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 410f008d59cb..6bd1c47071f3 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -200,6 +200,10 @@ FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner) dtype_(src->dtype()), fill_value_(ir_cloner->clone(src->fill_value_)) {} +Expr* FullOp::shallowCopy() const { + return IrBuilder::create(output(0), fill_value_, dtype_); +} + bool FullOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -242,6 +246,41 @@ ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner) step_(ir_cloner->clone(src->step_)), linear_index_(ir_cloner->clone(src->linear_index_)) {} +Expr* ARangeOp::shallowCopy() const { + return IrBuilder::create( + output(0), start_, end_, step_, dtype_, linear_index_); +} + +bool ARangeOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (!other->isA()) { + return false; + } + const auto other_op = other->as(); + if (dtype_ != other_op->dtype_) { + return false; + } + if (!start_->sameAs(other_op->start_)) { + return false; + } + if (!end_->sameAs(other_op->end_)) { + return false; + } + if (!step_->sameAs(other_op->step_)) { + return false; + } + if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) { + return false; + } + if ((linear_index_ != nullptr) && + !linear_index_->sameAs(other_op->linear_index_)) { + return false; + } + return Expr::sameAs(other); +} + EyeOp::EyeOp( IrBuilderPasskey passkey, Val* out, @@ -268,6 +307,10 @@ EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner) index1_(ir_cloner->clone(src->index1_)), index2_(ir_cloner->clone(src->index2_)) {} +Expr* EyeOp::shallowCopy() const { + return IrBuilder::create(output(0), dtype_, index1_, index2_); +} + bool EyeOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -294,36 +337,6 @@ bool EyeOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } -bool ARangeOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (dtype_ != other_op->dtype_) { - return false; - } - if (!start_->sameAs(other_op->start_)) { - return false; - } - if (!end_->sameAs(other_op->end_)) { - return false; - } - if (!step_->sameAs(other_op->step_)) { - return false; - } - if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) { - return false; - } - if ((linear_index_ != nullptr) && - !linear_index_->sameAs(other_op->linear_index_)) { - return false; - } - return Expr::sameAs(other); -} - UnaryOp::UnaryOp( IrBuilderPasskey passkey, UnaryOpType type, @@ -344,6 +357,10 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} +Expr* UnaryOp::shallowCopy() const { + return IrBuilder::create(unary_op_type_, out_, in_); +} + bool UnaryOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -381,6 +398,10 @@ BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) lhs_(ir_cloner->clone(src->lhs_)), rhs_(ir_cloner->clone(src->rhs_)) {} +Expr* BinaryOp::shallowCopy() const { + return IrBuilder::create(binary_op_type_, out_, lhs_, rhs_); +} + bool BinaryOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -422,6 +443,10 @@ TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) in2_(ir_cloner->clone(src->in2_)), in3_(ir_cloner->clone(src->in3_)) {} +Expr* 
TernaryOp::shallowCopy() const { + return IrBuilder::create(ternary_op_type_, out_, in1_, in2_, in3_); +} + bool TernaryOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -472,6 +497,11 @@ RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner) rng_offset_(src->rng_offset_), philox_index_(ir_cloner->clone(src->philox_index_)) {} +Expr* RNGOp::shallowCopy() const { + return IrBuilder::create( + rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_); +} + bool RNGOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -583,6 +613,10 @@ BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)), is_broadcast_dims_(src->is_broadcast_dims_) {} +Expr* BroadcastOp::shallowCopy() const { + return IrBuilder::create(out_, in_, is_broadcast_dims_); +} + bool BroadcastOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -637,6 +671,34 @@ ReductionOp::ReductionOp( addInput(in); } +ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + reduction_op_type_(src->reduction_op_type_), + init_(ir_cloner->clone(src->init_)), + out_(ir_cloner->clone(src->out_)), + in_(ir_cloner->clone(src->in_)), + is_allreduce_(src->is_allreduce_) {} + +Expr* ReductionOp::shallowCopy() const { + return IrBuilder::create( + reduction_op_type_, init_, out_, in_, is_allreduce_, etype()); +} + +bool ReductionOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (!other->isA()) { + return false; + } + const auto other_op = other->as(); + // Note that init is not part of input vals, so it must be checked separately. + return ( + Expr::sameAs(other) && + getReductionOpType() == other_op->getReductionOpType() && + init()->sameAs(other_op->init())); +} + GroupedReductionOp::GroupedReductionOp( IrBuilderPasskey passkey, std::vector reduction_op_types, @@ -666,6 +728,16 @@ GroupedReductionOp::GroupedReductionOp( init_vals_(ir_cloner->clone(src->init_vals_)), is_allreduce_(src->is_allreduce_) {} +Expr* GroupedReductionOp::shallowCopy() const { + return IrBuilder::create( + reduction_op_types_, + init_vals_, + outputs(), + inputs(), + is_allreduce_, + etype()); +} + int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { auto it = std::find(outputs().begin(), outputs().end(), output_val); if (it != outputs().end()) { @@ -840,6 +912,10 @@ WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) init_(src->init_.clone(ir_cloner)), is_allreduce_(src->is_allreduce_) {} +Expr* WelfordOp::shallowCopy() const { + return IrBuilder::create(output_, input_, init_, is_allreduce_); +} + Val* WelfordOp::getInitValOfOutput(Val* output_val) const { auto val_name = output().getNameOf(output_val); @@ -989,6 +1065,11 @@ GroupedWelfordOp::GroupedWelfordOp( init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)), is_allreduce_(src->is_allreduce_) {} +Expr* GroupedWelfordOp::shallowCopy() const { + return IrBuilder::create( + output_vals_, input_vals_, init_vals_, is_allreduce_, etype()); +} + bool GroupedWelfordOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -1083,6 +1164,12 @@ MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) init_(ir_cloner->clone(src->init_)), options_(src->options_) {} +Expr* MmaOp::shallowCopy() const { + auto result = IrBuilder::create(out_, in_a_, in_b_, init_); + result->options_ = options_; + return result; +} + bool MmaOp::sameAs(const Statement* other) const 
{ if (this == other) { return true; @@ -1095,29 +1182,6 @@ bool MmaOp::sameAs(const Statement* other) const { return false; } -ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - reduction_op_type_(src->reduction_op_type_), - init_(ir_cloner->clone(src->init_)), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_allreduce_(src->is_allreduce_) {} - -bool ReductionOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - // Note that init is not part of input vals, so it must be checked separately. - return ( - Expr::sameAs(other) && - getReductionOpType() == other_op->getReductionOpType() && - init()->sameAs(other_op->init())); -} - TransposeOp::TransposeOp( IrBuilderPasskey passkey, TensorView* out, @@ -1162,6 +1226,10 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)), new2old_(src->new2old_) {} +Expr* TransposeOp::shallowCopy() const { + return IrBuilder::create(out_, in_, new2old_); +} + std::vector TransposeOp::old2new() const { std::vector old2new(new2old_.size()); for (auto new_axis : c10::irange(new2old_.size())) { @@ -1201,6 +1269,10 @@ ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner) } } +Expr* ExpandOp::shallowCopy() const { + return IrBuilder::create(out_, in_, expanded_extents_); +} + ShiftOp::ShiftOp( IrBuilderPasskey passkey, Val* out, @@ -1248,6 +1320,10 @@ ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) offsets_(src->offsets_), pad_width_(src->pad_width_) {} +Expr* ShiftOp::shallowCopy() const { + return IrBuilder::create(out_, in_, offsets_, pad_width_); +} + bool ShiftOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -1310,6 +1386,10 @@ GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) window_shape_(src->window_shape_), pad_width_(src->pad_width_) {} +Expr* GatherOp::shallowCopy() const { + return IrBuilder::create(out_, in_, window_shape_, pad_width_); +} + bool GatherOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -1356,6 +1436,10 @@ ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner) vector_id_(ir_cloner->clone(src->vector_id_)), index_(ir_cloner->clone(src->index_)) {} +Expr* ViewAsScalar::shallowCopy() const { + return IrBuilder::create(out_, in_, vector_id_, index_); +} + ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) : Expr(passkey, ExprType::ViewOp), out_(out), in_(in) { addOutput(out); @@ -1367,6 +1451,10 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} +Expr* ViewOp::shallowCopy() const { + return IrBuilder::create(out_, in_); +} + LoadStoreOp::LoadStoreOp( IrBuilderPasskey passkey, LoadStoreOpType op_type, @@ -1386,6 +1474,10 @@ LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} +Expr* LoadStoreOp::shallowCopy() const { + return IrBuilder::create(load_store_type_, out_, in_); +} + IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) : start_(_start), extent_(_extent) { TORCH_INTERNAL_ASSERT( @@ -2495,6 +2587,11 @@ Split::Split(const Split* src, IrCloner* ir_cloner) start_offset_(ir_cloner->clone(src->start_offset_)), stop_offset_(ir_cloner->clone(src->stop_offset_)) {} +Expr* Split::shallowCopy() const { + return 
IrBuilder::create( + outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_); +} + Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { TORCH_INTERNAL_ASSERT(in_extent != nullptr); @@ -2540,6 +2637,10 @@ Merge::Merge(const Merge* src, IrCloner* ir_cloner) outer_(ir_cloner->clone(src->outer_)), inner_(ir_cloner->clone(src->inner_)) {} +Expr* Merge::shallowCopy() const { + return IrBuilder::create(out_, outer_, inner_); +} + bool Merge::sameAs(const Statement* other) const { if (this == other) { return true; @@ -2571,6 +2672,11 @@ Swizzle2D::Swizzle2D( addInput(in_y); } +Expr* Swizzle2D::shallowCopy() const { + return IrBuilder::create( + out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_); +} + bool Swizzle2D::sameAs(const Statement* other) const { if (this == other) { return true; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 132b99b31c34..e277ecd709c3 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -78,6 +78,15 @@ TensorIndex::TensorIndex( } } +Val* TensorIndex::index(int i) const { + TORCH_INTERNAL_ASSERT( + nDims() > 0, "Tried to get an index of a 0-dim TensorIndex"); + if (i < 0) + i += nDims(); + TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims())); + return indices_[i]; +} + BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) : Expr(passkey, ExprType::BlockSync), war_sync_(war_sync) { TORCH_INTERNAL_ASSERT( @@ -85,6 +94,10 @@ BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) "IR type only valid for Kernel container."); } +Expr* BlockSync::shallowCopy() const { + return IrBuilder::create(war_sync_); +} + GridSync::GridSync( IrBuilderPasskey passkey, ParallelTypeBitmap sync_dims, @@ -93,6 +106,10 @@ GridSync::GridSync( sync_dims_(sync_dims), sync_buffer_(sync_buffer) {} +Expr* GridSync::shallowCopy() const { + return IrBuilder::create(sync_dims_, sync_buffer_); +} + CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) : Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) { TORCH_INTERNAL_ASSERT( @@ -100,6 +117,10 @@ CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) "IR type only valid for Kernel container."); } +Expr* CpAsyncWait::shallowCopy() const { + return IrBuilder::create(keep_stages_); +} + CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) : Expr(passkey, ExprType::CpAsyncCommit) { TORCH_INTERNAL_ASSERT( @@ -107,6 +128,10 @@ CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) "IR type only valid for Kernel container."); } +Expr* CpAsyncCommit::shallowCopy() const { + return IrBuilder::create(); +} + InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey, ExprType::InitMagicZero) { TORCH_INTERNAL_ASSERT( @@ -114,6 +139,10 @@ InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) "IR type only valid for Kernel container."); } +Expr* InitMagicZero::shallowCopy() const { + return IrBuilder::create(); +} + UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) : Expr(passkey, ExprType::UpdateMagicZero) { TORCH_INTERNAL_ASSERT( @@ -121,6 +150,10 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) "IR type only valid for Kernel container."); } +Expr* UpdateMagicZero::shallowCopy() const { + return IrBuilder::create(); +} + namespace { bool isIntegralScalar(const Val* val) { @@ -147,6 +180,10 @@ PairSelect::PairSelect( TORCH_INTERNAL_ASSERT(isIntegralScalar(out), "Integer only for this op"); } +Expr* 
PairSelect::shallowCopy() const { + return IrBuilder::create(out_, in_, selection_); +} + Swizzle2DInt::Swizzle2DInt( IrBuilderPasskey passkey, IntPair* out, @@ -172,6 +209,11 @@ Swizzle2DInt::Swizzle2DInt( addInput(extent_y); } +Expr* Swizzle2DInt::shallowCopy() const { + return IrBuilder::create( + out_, in_x_, in_y_, extent_x_, extent_y_, swizzle_type_); +} + void Scope::insert(std::vector::const_iterator pos, Expr* expr) { exprs_.insert(pos, expr); } @@ -307,6 +349,21 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other) "IR type only valid for Kernel container."); } +Expr* ForLoop::shallowCopy() const { + auto result = IrBuilder::create( + iter_domain_, + index_, + start_, + stop_, + step_, + vectorize_, + vectorize_shift_, + unroll_required_, + double_buffer_loop_stage_); + result->body_ = body_; + return result; +} + bool ForLoop::isUnrollable() const { // Start and stop must be constant, must not be a broadcast // dimension, cannot be bound to a parallel dimension, must not be @@ -426,13 +483,11 @@ IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond) addInput(cond); } -Val* TensorIndex::index(int i) const { - TORCH_INTERNAL_ASSERT( - nDims() > 0, "Tried to get an index of a 0-dim TensorIndex"); - if (i < 0) - i += nDims(); - TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims())); - return indices_[i]; +Expr* IfThenElse::shallowCopy() const { + auto result = IrBuilder::create(predicate()); + result->then_body_ = then_body_; + result->else_body_ = else_body_; + return result; } Allocate::Allocate( @@ -495,6 +550,10 @@ Allocate::Allocate( "IR type only valid for Kernel container."); } +Expr* Allocate::shallowCopy() const { + return IrBuilder::create(buffer_, memory_type_, shape_, zero_init_); +} + GridReduction::GridReduction( IrBuilderPasskey passkey, BinaryOpType reduction_op_type, @@ -553,6 +612,19 @@ GroupedGridReduction::GroupedGridReduction( "IR type only valid for Kernel container."); } +Expr* GridReduction::shallowCopy() const { + return IrBuilder::create( + getReductionOpType(), + init(), + out(), + in(), + reduction_buffer_, + sync_buffer_, + entrance_index_, + entrances_, + isAllreduce()); +} + GridBroadcast::GridBroadcast( IrBuilderPasskey passkey, BroadcastOp* broadcast_op, @@ -567,6 +639,11 @@ GridBroadcast::GridBroadcast( "IR type only valid for Kernel container."); } +Expr* GridBroadcast::shallowCopy() const { + return IrBuilder::create( + broadcast_op_, broadcast_buffer_, sync_buffer_); +} + GridWelford::GridWelford( IrBuilderPasskey passkey, WelfordOp* welford_op, @@ -589,6 +666,17 @@ GridWelford::GridWelford( "IR type only valid for Kernel container."); } +Expr* GridWelford::shallowCopy() const { + return IrBuilder::create( + welford_op_, + var_buffer_, + avg_buffer_, + n_buffer_, + sync_buffer_, + entrance_index_, + entrances_); +} + GroupedGridWelford::GroupedGridWelford( IrBuilderPasskey passkey, std::vector output_vals, @@ -617,6 +705,19 @@ GroupedGridWelford::GroupedGridWelford( "IR type only valid for Kernel container."); } +Expr* GroupedGridWelford::shallowCopy() const { + return IrBuilder::create( + outputVals(), + inputVals(), + initVals(), + reduction_buffers_, + sync_buffer_, + entrance_index_, + entrances_, + buffer_stride_, + isAllreduce()); +} + AllocateFusedReduction::AllocateFusedReduction( IrBuilderPasskey passkey, GridReduction* grid_reduction) @@ -657,6 +758,24 @@ AllocateFusedReduction::AllocateFusedReduction( "IR type only valid for Kernel container."); } +Expr* AllocateFusedReduction::shallowCopy() const { + if 
(grid_expr_->isA()) { + return IrBuilder::create( + grid_expr_->as()); + } else if (grid_expr_->isA()) { + return IrBuilder::create( + grid_expr_->as()); + } else if (grid_expr_->isA()) { + return IrBuilder::create( + grid_expr_->as()); + } else if (grid_expr_->isA()) { + return IrBuilder::create( + grid_expr_->as()); + } + TORCH_INTERNAL_ASSERT( + false, "Unknown reduction type in AllocateFusedReduction::shallowCopy"); +} + TensorIndex* AllocateFusedReduction::out() const { TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr); if (grid_expr_->isA() || diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 62b245772dd0..c77efef67aef 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -199,6 +199,8 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { Val* size, bool zero_init = false); + Expr *shallowCopy() const override; + Val* buffer() const { return buffer_; } @@ -251,6 +253,8 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr { public: explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); + Expr *shallowCopy() const override; + bool isWarHazardSync() const { return war_sync_; } @@ -265,6 +269,8 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { public: explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0); + Expr *shallowCopy() const override; + //! Returns the remaining number of stages that are not synchronized //! after this op. unsigned int keepStages() const { @@ -282,6 +288,8 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { public: explicit CpAsyncCommit(IrBuilderPasskey passkey); + + Expr *shallowCopy() const override; }; // Synchronize all blocks in device, implies cooperative group launch is @@ -293,6 +301,8 @@ class TORCH_CUDA_CU_API GridSync final : public Expr { ParallelTypeBitmap sync_dims, Val* sync_buffer); + Expr *shallowCopy() const override; + ParallelTypeBitmap syncDims() const { return sync_dims_; } @@ -311,6 +321,8 @@ class TORCH_CUDA_CU_API GridSync final : public Expr { class TORCH_CUDA_CU_API InitMagicZero final : public Expr { public: explicit InitMagicZero(IrBuilderPasskey passkey); + + Expr *shallowCopy() const override; }; // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero @@ -318,6 +330,8 @@ class TORCH_CUDA_CU_API InitMagicZero final : public Expr { class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr { public: explicit UpdateMagicZero(IrBuilderPasskey passkey); + + Expr *shallowCopy() const override; }; // TODO(kir): promote to IR node @@ -418,6 +432,8 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { ForLoop(IrBuilderPasskey passkey, const ForLoop* other); + Expr *shallowCopy() const override; + Val* index() const { return index_; } @@ -512,6 +528,8 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr { public: explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond); + Expr *shallowCopy() const override; + Scope& thenBody() { return then_body_; } @@ -557,6 +575,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { Val* entrances, bool is_allreduce = false); + Expr *shallowCopy() const override; + Allocate* reduction_buffer() const { return reduction_buffer_; } @@ -671,6 +691,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { Allocate* broadcast_buffer, Allocate* sync_buffer); + Expr *shallowCopy() const override; + BroadcastOp* broadcast_op() const { 
return broadcast_op_; } @@ -710,6 +732,8 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { Val* entrance_index, Val* entrances); + Expr *shallowCopy() const override; + WelfordOp* welford_op() const { return welford_op_; } @@ -777,6 +801,8 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { Val* buffer_stride, bool is_allreduce = false); + Expr *shallowCopy() const override; + const std::array, 3>& reduction_buffers() const { return reduction_buffers_; } @@ -839,6 +865,8 @@ class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { IrBuilderPasskey passkey, GroupedGridWelford* grouped_grid_welford); + Expr *shallowCopy() const override; + Expr* gridExpr() const { return grid_expr_; } @@ -879,6 +907,8 @@ class TORCH_CUDA_CU_API PairSelect : public Expr { PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection); + Expr *shallowCopy() const override; + Val* out() const { return out_; } @@ -914,6 +944,8 @@ class TORCH_CUDA_CU_API Swizzle2DInt : public Expr { Val* extent_y, Swizzle2DType swizzle_type); + Expr *shallowCopy() const override; + IntPair* out() const { return out_; } diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index e717a91819ad..654a995f0a81 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -1001,6 +1001,9 @@ struct DummyExpr : public Expr { DummyExpr& operator=(const DummyExpr& other) = delete; DummyExpr(DummyExpr&& other) = delete; DummyExpr& operator=(DummyExpr&& other) = delete; + Expr* shallowCopy() const override { + return nullptr; + } }; TEST_F(NVFuserTest, FusionTopoSort_CUDA) { From fb6e6f38f28430dbf768a2cc4a3f7647fbaea3a8 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 00:43:21 -0700 Subject: [PATCH 02/15] withPredicate --- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index b29a8bc417cd..b594bc7a06d7 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -341,6 +341,12 @@ void Expr::setPredicate(kir::Predicate* predicate) { predicate_ = predicate; } +Expr* Expr::withPredicate(kir::Predicate* predicate) { + auto result = shallowCopy(); + result->setPredicate(predicate); + return result; +} + kir::Predicate* Expr::writePredicate() const { TORCH_INTERNAL_ASSERT( container()->isA(), "Function invalid for fusion."); @@ -353,6 +359,12 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) { write_predicate_ = write_predicate; } +Expr* Expr::withWritePredicate(kir::Predicate* predicate) { + auto result = shallowCopy(); + result->setWritePredicate(predicate); + return result; +} + } // namespace cuda } // namespace fuser } // namespace jit From 051f9a38a12bbc466ac285528156d1801ce19cf3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 01:25:49 -0700 Subject: [PATCH 03/15] setPredicate->withPredicate --- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 4 +- torch/csrc/jit/codegen/cuda/kernel_ir.h | 24 ++++--- torch/csrc/jit/codegen/cuda/lower_index.cpp | 63 ++++++++++++------- .../csrc/jit/codegen/cuda/lower_predicate.cpp | 8 +-- torch/csrc/jit/codegen/cuda/lower_shift.cpp | 11 ++-- torch/csrc/jit/codegen/cuda/lower_shift.h | 2 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 9 ++- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 4 +- 8 files 
changed, 78 insertions(+), 47 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index ad3e010f5362..9bf5d0ee85c7 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -477,14 +477,14 @@ class TORCH_CUDA_CU_API Expr : public Statement { // TODO: Protect based on being in kernel container Expr* withWritePredicate(kir::Predicate* write_predicate); + protected: + // TODO: Protect based on being in kernel container void setPredicate(kir::Predicate* predicate); // TODO: Protect based on being in kernel container void setWritePredicate(kir::Predicate* write_predicate); - protected: - // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index c77efef67aef..0f749ef5ce4a 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -599,8 +599,10 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GridReduction* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: @@ -659,8 +661,10 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GroupedGridReduction *withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: @@ -768,8 +772,10 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: @@ -829,8 +835,10 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GroupedGridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index f03f36bdaa71..9f72e977fcde 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -411,10 +411,12 @@ void IndexLowering::handleBlockReduction( ReductionOp* indexed_rop = IrBuilder::create( rop->getReductionOpType(), rop->init(), out, in, rop->isAllreduce()); if (rop->predicate()) { - indexed_rop->setPredicate(rop->predicate()); + indexed_rop = + indexed_rop->withPredicate(rop->predicate())->as(); } if (rop->writePredicate()) { - indexed_rop->setWritePredicate(rop->writePredicate()); + indexed_rop = indexed_rop->withWritePredicate(rop->writePredicate()) + ->as(); } pushBack(indexed_rop); @@ -493,13 +495,15 @@ void 
IndexLowering::handleGridReduction( n_entrances, rop->isAllreduce()); - grid_reduction->setThreadPredicate(thread_pred); + grid_reduction = grid_reduction->withThreadPredicate(thread_pred); if (rop->predicate()) { - grid_reduction->setPredicate(rop->predicate()); + grid_reduction = grid_reduction->withPredicate(rop->predicate()) + ->as(); } if (rop->writePredicate()) { - grid_reduction->setWritePredicate(rop->writePredicate()); + grid_reduction = grid_reduction->withWritePredicate(rop->writePredicate()) + ->as(); } pushBack(grid_reduction); @@ -556,10 +560,12 @@ void IndexLowering::handleBlockReduction( inputs, grouped_rop->isAllreduce()); if (grouped_rop->predicate()) { - indexed_rop->setPredicate(grouped_rop->predicate()); + indexed_rop = indexed_rop->withPredicate(grouped_rop->predicate()) + ->as(); } if (grouped_rop->writePredicate()) { - indexed_rop->setWritePredicate(grouped_rop->writePredicate()); + indexed_rop = indexed_rop->withWritePredicate(grouped_rop->writePredicate()) + ->as(); } pushBack(indexed_rop); @@ -638,13 +644,16 @@ void IndexLowering::handleGridReduction( work_buf_size_info.buffer_stride, grouped_rop->isAllreduce()); - grid_reduction->setThreadPredicate(thread_pred); + grid_reduction = grid_reduction->withThreadPredicate(thread_pred); if (grouped_rop->predicate()) { - grid_reduction->setPredicate(grouped_rop->predicate()); + grid_reduction = grid_reduction->withPredicate(grouped_rop->predicate()) + ->as(); } if (grouped_rop->writePredicate()) { - grid_reduction->setWritePredicate(grouped_rop->writePredicate()); + grid_reduction = + grid_reduction->withWritePredicate(grouped_rop->writePredicate()) + ->as(); } pushBack(grid_reduction); @@ -706,10 +715,11 @@ void IndexLowering::handle(const WelfordOp* wop) { wop->isAllreduce()); if (wop->predicate()) { - indexed_wop->setPredicate(wop->predicate()); + indexed_wop = indexed_wop->withPredicate(wop->predicate())->as(); } if (wop->writePredicate()) { - indexed_wop->setWritePredicate(wop->writePredicate()); + indexed_wop = + indexed_wop->withWritePredicate(wop->writePredicate())->as(); } // Serial welford @@ -785,22 +795,27 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { entrance_ind, n_entrances); - grid_welford->setThreadPredicate(thread_pred); + grid_welford = grid_welford->withThreadPredicate(thread_pred); const bool block_reduce_separated = out_domain->hasBlockReduction() && !indexed_wop->isAllreduce(); if (indexed_wop->predicate()) { if (block_reduce_separated) { - grid_welford->setPredicate(IrBuilder::create( - GpuLower::current()->kernel()->trueVal())); + grid_welford = grid_welford + ->withPredicate(IrBuilder::create( + GpuLower::current()->kernel()->trueVal())) + ->as(); } else { - grid_welford->setPredicate(indexed_wop->predicate()); + grid_welford = grid_welford->withPredicate(indexed_wop->predicate()) + ->as(); } } if (indexed_wop->writePredicate()) { - grid_welford->setWritePredicate(indexed_wop->writePredicate()); + grid_welford = + grid_welford->withWritePredicate(indexed_wop->writePredicate()) + ->as(); } if (block_reduce_separated) { @@ -945,13 +960,15 @@ void IndexLowering::handleGroupedGridWelford( work_buf_size_info.buffer_stride, op->isAllreduce()); - indexed_op->setThreadPredicate(thread_pred); + indexed_op = indexed_op->withThreadPredicate(thread_pred); if (op->predicate()) { - indexed_op->setPredicate(op->predicate()); + indexed_op = indexed_op->withPredicate(op->predicate()) + ->as(); } if (op->writePredicate()) { - indexed_op->setWritePredicate(op->writePredicate()); + indexed_op 
= indexed_op->withWritePredicate(op->writePredicate()) + ->as(); } pushBack(indexed_op); @@ -997,7 +1014,8 @@ void IndexLowering::handle(const BroadcastOp* bop) { const bool block_z = parallel_bitmap.get(ParallelType::BIDz); if (bop->predicate()) { - indexed_expr->setPredicate(bop->predicate()); + indexed_expr = + indexed_expr->withPredicate(bop->predicate())->as(); } const bool grid_broadcast_needed = block_x || block_y || block_z; @@ -1024,7 +1042,8 @@ void IndexLowering::handle(const BroadcastOp* bop) { indexed_expr, work_buffer, sync_buffer); if (bop->predicate()) { - grid_broadcast->setPredicate(bop->predicate()); + grid_broadcast = grid_broadcast->withPredicate(bop->predicate()) + ->as(); } pushBack(grid_broadcast); diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 989c00be81b7..73c410de741d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -20,7 +20,7 @@ namespace cuda { namespace { -class ConditionalFromPredicateModifier : public kir::IrVisitor { +class ConditionalFromPredicateModifier : public kir::ExprMutator { public: ConditionalFromPredicateModifier() = delete; @@ -33,10 +33,10 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { ConditionalFromPredicateModifier(const std::vector& exprs) { FUSER_PERF_SCOPE( "GpuLower::Lower::ConditionalFromPredicateModifier::process"); - kir::IrVisitor::handle(exprs); + kir::ExprMutator::handle(exprs); } - using kir::IrVisitor::handle; + using kir::ExprMutator::handle; void handle(Expr* expr) final { if (expr != nullptr && expr->predicate() != nullptr) { @@ -131,7 +131,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { } else { // If generateConditional returns null, it means no specific // predicate needs to be used. - expr->setWritePredicate(nullptr); + registerReplace(expr, expr->withWritePredicate(nullptr)); // shallow copy clears predicates } } } diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 991ed5745c87..2a7c04243f4c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -17,7 +17,7 @@ namespace jit { namespace fuser { namespace cuda { -void ShiftPredicateInserter::insert( +Expr* ShiftPredicateInserter::insert( Expr* expr, const std::vector& loops, Bool* thread_pred, @@ -30,7 +30,7 @@ void ShiftPredicateInserter::insert( const bool needs_shift_predicate = gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition()); if (!needs_shift_predicate) { - return; + return expr; } // The conditional branches to create: @@ -57,8 +57,7 @@ void ShiftPredicateInserter::insert( // the expr with shift_pred. Since the expr is not shift, the // padding is safe to omit. if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) { - expr->setPredicate(shift_pred); - return; + return expr->withPredicate(shift_pred); } auto shift_ite = IrBuilder::create(shift_pred); @@ -76,7 +75,7 @@ void ShiftPredicateInserter::insert( // No padding condition is required if this is within unswitch. 
if (within_unswitch) { - return; + return expr; } // Padding by zero @@ -89,6 +88,8 @@ void ShiftPredicateInserter::insert( bounds_ite->thenBody().push_back(pad_expr); // Insert the else block shift_ite->elseBody().push_back(bounds_ite); + + return expr; } int AxisHaloInfo::width() const { diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h index 0cb3c3ea4457..f12410703d99 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.h +++ b/torch/csrc/jit/codegen/cuda/lower_shift.h @@ -225,7 +225,7 @@ class ShiftPredicateInserter { //! the generated predicate. The branch structure is different from //! the usual predicated expression, so the insertion is also done //! here. - static void insert( + static Expr* insert( Expr* expr, const std::vector& loops, Bool* thread_pred, diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 8ba3909d19e7..3996eeb5fe55 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -82,8 +82,11 @@ void UnrollPass::handle(Expr* expr) { // When a predicate needs to account for ShiftOp, it is currently // taken care by its own function. if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) { - ShiftPredicateInserter::insert( + auto expr_with_predicate = ShiftPredicateInserter::insert( expr, for_loops_, thread_pred, unswitched_loop_); + if (expr_with_predicate != expr) { + registerReplace(expr, expr_with_predicate); + } return; } @@ -93,7 +96,7 @@ void UnrollPass::handle(Expr* expr) { ? thread_pred_expr : IrBuilder::create( PredicateType::ReductionWrite, expr, thread_pred); - expr->setWritePredicate(write_pred); + registerReplace(expr, expr->withWritePredicate(write_pred)); } // For expr calling a device func with block sync, don't create @@ -103,7 +106,7 @@ void UnrollPass::handle(Expr* expr) { ? 
thread_pred_expr : IrBuilder::create( PredicateType::Inline, expr, thread_pred); - expr->setPredicate(pred); + registerReplace(expr, expr->withPredicate(pred)); return; } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index c3dc35c3412e..b98fbea263c3 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -446,8 +446,8 @@ class ReplaceExprInput : private kir::ExprMutator { // Copy predicates and register expression replacement void registerReplaceWithPredicate(Expr* old_expr, Expr* new_expr) { - new_expr->setPredicate(old_expr->predicate()); - new_expr->setWritePredicate(old_expr->writePredicate()); + new_expr = new_expr->withPredicate(old_expr->predicate()) + ->withWritePredicate(old_expr->writePredicate()); registerReplace(old_expr, new_expr); } From 3814a5c623cbf2f686cc65dac7e8f4f0d2e02ca0 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 03:00:42 -0700 Subject: [PATCH 04/15] fix shallow copy --- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 161 +++++++++++++++++++--- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 89 +++++++++--- 2 files changed, 211 insertions(+), 39 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 6bd1c47071f3..2dec991dd9a5 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -201,7 +201,12 @@ FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner) fill_value_(ir_cloner->clone(src->fill_value_)) {} Expr* FullOp::shallowCopy() const { - return IrBuilder::create(output(0), fill_value_, dtype_); + auto result = IrBuilder::create(output(0), fill_value_, dtype_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool FullOp::sameAs(const Statement* other) const { @@ -247,8 +252,13 @@ ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner) linear_index_(ir_cloner->clone(src->linear_index_)) {} Expr* ARangeOp::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( output(0), start_, end_, step_, dtype_, linear_index_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool ARangeOp::sameAs(const Statement* other) const { @@ -308,7 +318,12 @@ EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner) index2_(ir_cloner->clone(src->index2_)) {} Expr* EyeOp::shallowCopy() const { - return IrBuilder::create(output(0), dtype_, index1_, index2_); + auto result = IrBuilder::create(output(0), dtype_, index1_, index2_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool EyeOp::sameAs(const Statement* other) const { @@ -358,7 +373,12 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)) {} Expr* UnaryOp::shallowCopy() const { - return IrBuilder::create(unary_op_type_, out_, in_); + auto result = IrBuilder::create(unary_op_type_, out_, in_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool UnaryOp::sameAs(const Statement* other) const { @@ -399,7 +419,12 @@ BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) rhs_(ir_cloner->clone(src->rhs_)) {} Expr* BinaryOp::shallowCopy() const { - return IrBuilder::create(binary_op_type_, out_, lhs_, 
rhs_); + auto result = IrBuilder::create(binary_op_type_, out_, lhs_, rhs_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool BinaryOp::sameAs(const Statement* other) const { @@ -444,7 +469,13 @@ TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) in3_(ir_cloner->clone(src->in3_)) {} Expr* TernaryOp::shallowCopy() const { - return IrBuilder::create(ternary_op_type_, out_, in1_, in2_, in3_); + auto result = + IrBuilder::create(ternary_op_type_, out_, in1_, in2_, in3_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool TernaryOp::sameAs(const Statement* other) const { @@ -498,8 +529,13 @@ RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner) philox_index_(ir_cloner->clone(src->philox_index_)) {} Expr* RNGOp::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool RNGOp::sameAs(const Statement* other) const { @@ -614,7 +650,12 @@ BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner) is_broadcast_dims_(src->is_broadcast_dims_) {} Expr* BroadcastOp::shallowCopy() const { - return IrBuilder::create(out_, in_, is_broadcast_dims_); + auto result = IrBuilder::create(out_, in_, is_broadcast_dims_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool BroadcastOp::sameAs(const Statement* other) const { @@ -680,8 +721,13 @@ ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) is_allreduce_(src->is_allreduce_) {} Expr* ReductionOp::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( reduction_op_type_, init_, out_, in_, is_allreduce_, etype()); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool ReductionOp::sameAs(const Statement* other) const { @@ -729,13 +775,18 @@ GroupedReductionOp::GroupedReductionOp( is_allreduce_(src->is_allreduce_) {} Expr* GroupedReductionOp::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( reduction_op_types_, init_vals_, outputs(), inputs(), is_allreduce_, etype()); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { @@ -913,7 +964,13 @@ WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) is_allreduce_(src->is_allreduce_) {} Expr* WelfordOp::shallowCopy() const { - return IrBuilder::create(output_, input_, init_, is_allreduce_); + auto result = + IrBuilder::create(output_, input_, init_, is_allreduce_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } Val* WelfordOp::getInitValOfOutput(Val* output_val) const { @@ -1066,8 +1123,13 @@ GroupedWelfordOp::GroupedWelfordOp( is_allreduce_(src->is_allreduce_) {} Expr* GroupedWelfordOp::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( output_vals_, input_vals_, init_vals_, is_allreduce_, etype()); + if (container()->isA()) { + 
result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool GroupedWelfordOp::sameAs(const Statement* other) const { @@ -1167,6 +1229,10 @@ MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) Expr* MmaOp::shallowCopy() const { auto result = IrBuilder::create(out_, in_a_, in_b_, init_); result->options_ = options_; + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } return result; } @@ -1227,7 +1293,12 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) new2old_(src->new2old_) {} Expr* TransposeOp::shallowCopy() const { - return IrBuilder::create(out_, in_, new2old_); + auto result = IrBuilder::create(out_, in_, new2old_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } std::vector TransposeOp::old2new() const { @@ -1270,7 +1341,12 @@ ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner) } Expr* ExpandOp::shallowCopy() const { - return IrBuilder::create(out_, in_, expanded_extents_); + auto result = IrBuilder::create(out_, in_, expanded_extents_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } ShiftOp::ShiftOp( @@ -1321,7 +1397,12 @@ ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) pad_width_(src->pad_width_) {} Expr* ShiftOp::shallowCopy() const { - return IrBuilder::create(out_, in_, offsets_, pad_width_); + auto result = IrBuilder::create(out_, in_, offsets_, pad_width_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool ShiftOp::sameAs(const Statement* other) const { @@ -1387,7 +1468,13 @@ GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) pad_width_(src->pad_width_) {} Expr* GatherOp::shallowCopy() const { - return IrBuilder::create(out_, in_, window_shape_, pad_width_); + auto result = + IrBuilder::create(out_, in_, window_shape_, pad_width_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool GatherOp::sameAs(const Statement* other) const { @@ -1437,7 +1524,12 @@ ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner) index_(ir_cloner->clone(src->index_)) {} Expr* ViewAsScalar::shallowCopy() const { - return IrBuilder::create(out_, in_, vector_id_, index_); + auto result = IrBuilder::create(out_, in_, vector_id_, index_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) @@ -1452,7 +1544,12 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)) {} Expr* ViewOp::shallowCopy() const { - return IrBuilder::create(out_, in_); + auto result = IrBuilder::create(out_, in_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } LoadStoreOp::LoadStoreOp( @@ -1475,7 +1572,12 @@ LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)) {} Expr* LoadStoreOp::shallowCopy() const { - return IrBuilder::create(load_store_type_, out_, in_); + auto result = IrBuilder::create(load_store_type_, out_, in_); + if (container()->isA()) { + result->setPredicate(predicate()); 
+ result->setWritePredicate(writePredicate()); + } + return result; } IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) @@ -2588,8 +2690,13 @@ Split::Split(const Split* src, IrCloner* ir_cloner) stop_offset_(ir_cloner->clone(src->stop_offset_)) {} Expr* Split::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { @@ -2638,7 +2745,12 @@ Merge::Merge(const Merge* src, IrCloner* ir_cloner) inner_(ir_cloner->clone(src->inner_)) {} Expr* Merge::shallowCopy() const { - return IrBuilder::create(out_, outer_, inner_); + auto result = IrBuilder::create(out_, outer_, inner_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool Merge::sameAs(const Statement* other) const { @@ -2673,8 +2785,13 @@ Swizzle2D::Swizzle2D( } Expr* Swizzle2D::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_); + if (container()->isA()) { + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + } + return result; } bool Swizzle2D::sameAs(const Statement* other) const { diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index e277ecd709c3..e284f916faed 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -95,7 +95,10 @@ BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) } Expr* BlockSync::shallowCopy() const { - return IrBuilder::create(war_sync_); + auto result = IrBuilder::create(war_sync_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } GridSync::GridSync( @@ -107,7 +110,10 @@ GridSync::GridSync( sync_buffer_(sync_buffer) {} Expr* GridSync::shallowCopy() const { - return IrBuilder::create(sync_dims_, sync_buffer_); + auto result = IrBuilder::create(sync_dims_, sync_buffer_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) @@ -118,7 +124,10 @@ CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) } Expr* CpAsyncWait::shallowCopy() const { - return IrBuilder::create(keep_stages_); + auto result = IrBuilder::create(keep_stages_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) @@ -129,7 +138,10 @@ CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) } Expr* CpAsyncCommit::shallowCopy() const { - return IrBuilder::create(); + auto result = IrBuilder::create(); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) @@ -140,7 +152,10 @@ InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) } Expr* InitMagicZero::shallowCopy() const { - return IrBuilder::create(); + auto result = IrBuilder::create(); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) @@ -151,7 
+166,10 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) } Expr* UpdateMagicZero::shallowCopy() const { - return IrBuilder::create(); + auto result = IrBuilder::create(); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } namespace { @@ -181,7 +199,10 @@ PairSelect::PairSelect( } Expr* PairSelect::shallowCopy() const { - return IrBuilder::create(out_, in_, selection_); + auto result = IrBuilder::create(out_, in_, selection_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } Swizzle2DInt::Swizzle2DInt( @@ -210,8 +231,11 @@ Swizzle2DInt::Swizzle2DInt( } Expr* Swizzle2DInt::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( out_, in_x_, in_y_, extent_x_, extent_y_, swizzle_type_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } void Scope::insert(std::vector::const_iterator pos, Expr* expr) { @@ -361,6 +385,8 @@ Expr* ForLoop::shallowCopy() const { unroll_required_, double_buffer_loop_stage_); result->body_ = body_; + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); return result; } @@ -487,6 +513,7 @@ Expr* IfThenElse::shallowCopy() const { auto result = IrBuilder::create(predicate()); result->then_body_ = then_body_; result->else_body_ = else_body_; + result->setWritePredicate(writePredicate()); return result; } @@ -551,7 +578,11 @@ Allocate::Allocate( } Expr* Allocate::shallowCopy() const { - return IrBuilder::create(buffer_, memory_type_, shape_, zero_init_); + auto result = + IrBuilder::create(buffer_, memory_type_, shape_, zero_init_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } GridReduction::GridReduction( @@ -613,7 +644,7 @@ GroupedGridReduction::GroupedGridReduction( } Expr* GridReduction::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( getReductionOpType(), init(), out(), @@ -623,6 +654,9 @@ Expr* GridReduction::shallowCopy() const { entrance_index_, entrances_, isAllreduce()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } GridBroadcast::GridBroadcast( @@ -640,8 +674,11 @@ GridBroadcast::GridBroadcast( } Expr* GridBroadcast::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( broadcast_op_, broadcast_buffer_, sync_buffer_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } GridWelford::GridWelford( @@ -667,7 +704,7 @@ GridWelford::GridWelford( } Expr* GridWelford::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( welford_op_, var_buffer_, avg_buffer_, @@ -675,6 +712,9 @@ Expr* GridWelford::shallowCopy() const { sync_buffer_, entrance_index_, entrances_); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } GroupedGridWelford::GroupedGridWelford( @@ -706,7 +746,7 @@ GroupedGridWelford::GroupedGridWelford( } Expr* GroupedGridWelford::shallowCopy() const { - return IrBuilder::create( + auto result = IrBuilder::create( outputVals(), inputVals(), initVals(), @@ -716,6 +756,9 @@ Expr* GroupedGridWelford::shallowCopy() const { entrances_, buffer_stride_, isAllreduce()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } AllocateFusedReduction::AllocateFusedReduction( @@ -760,17 
+803,29 @@ AllocateFusedReduction::AllocateFusedReduction( Expr* AllocateFusedReduction::shallowCopy() const { if (grid_expr_->isA()) { - return IrBuilder::create( + auto result = IrBuilder::create( grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } else if (grid_expr_->isA()) { - return IrBuilder::create( + auto result = IrBuilder::create( grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } else if (grid_expr_->isA()) { - return IrBuilder::create( + auto result = IrBuilder::create( grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } else if (grid_expr_->isA()) { - return IrBuilder::create( + auto result = IrBuilder::create( grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; } TORCH_INTERNAL_ASSERT( false, "Unknown reduction type in AllocateFusedReduction::shallowCopy"); From 41748e4893ab1367ce8c680f155ca9d0222d20ee Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 11:40:34 -0700 Subject: [PATCH 05/15] fix --- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 31 ++++++++++++++++++----- torch/csrc/jit/codegen/cuda/kernel_ir.h | 2 ++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index e284f916faed..5d05c5bcd5ad 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -613,6 +613,22 @@ GridReduction::GridReduction( "IR type only valid for Kernel container."); } +Expr* GridReduction::shallowCopy() const { + auto result = IrBuilder::create( + getReductionOpType(), + init(), + out(), + in(), + reduction_buffer_, + sync_buffer_, + entrance_index_, + entrances_, + isAllreduce()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; +} + GroupedGridReduction::GroupedGridReduction( IrBuilderPasskey passkey, std::vector reduction_op_types, @@ -643,16 +659,17 @@ GroupedGridReduction::GroupedGridReduction( "IR type only valid for Kernel container."); } -Expr* GridReduction::shallowCopy() const { - auto result = IrBuilder::create( - getReductionOpType(), - init(), - out(), - in(), - reduction_buffer_, +Expr* GroupedGridReduction::shallowCopy() const { + auto result = IrBuilder::create( + getReductionOpTypes(), + initVals(), + outputs(), + inputs(), + reduction_buffers_, sync_buffer_, entrance_index_, entrances_, + buffer_stride_, isAllreduce()); result->setPredicate(predicate()); result->setWritePredicate(writePredicate()); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 0f749ef5ce4a..030580cf50e6 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -631,6 +631,8 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { Val* buffer_stride, bool is_allreduce = false); + Expr *shallowCopy() const override; + const std::vector& reduction_buffers() const { return reduction_buffers_; } From 42f72875ca97bea44fa7047ac2989318ff41b77c Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 12:03:47 -0700 Subject: [PATCH 06/15] fix thread_predicate_ --- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp 
b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 5d05c5bcd5ad..8c8cb29dbb78 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -626,6 +626,7 @@ Expr* GridReduction::shallowCopy() const { isAllreduce()); result->setPredicate(predicate()); result->setWritePredicate(writePredicate()); + result->thread_predicate_ = thread_predicate_; return result; } @@ -673,6 +674,7 @@ Expr* GroupedGridReduction::shallowCopy() const { isAllreduce()); result->setPredicate(predicate()); result->setWritePredicate(writePredicate()); + result->thread_predicate_ = thread_predicate_; return result; } @@ -731,6 +733,7 @@ Expr* GridWelford::shallowCopy() const { entrances_); result->setPredicate(predicate()); result->setWritePredicate(writePredicate()); + result->thread_predicate_ = thread_predicate_; return result; } @@ -775,6 +778,7 @@ Expr* GroupedGridWelford::shallowCopy() const { isAllreduce()); result->setPredicate(predicate()); result->setWritePredicate(writePredicate()); + result->thread_predicate_ = thread_predicate_; return result; } From 389892df7015507c4d5efecb1b4429a645cd25cc Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 12:59:24 -0700 Subject: [PATCH 07/15] Revert "setPredicate->withPredicate" This reverts commit 051f9a38a12bbc466ac285528156d1801ce19cf3. --- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 4 ++-- torch/csrc/jit/codegen/cuda/lower_predicate.cpp | 8 ++++---- torch/csrc/jit/codegen/cuda/lower_shift.cpp | 11 +++++------ torch/csrc/jit/codegen/cuda/lower_shift.h | 2 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 9 +++------ 5 files changed, 15 insertions(+), 19 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 9bf5d0ee85c7..ad3e010f5362 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -477,14 +477,14 @@ class TORCH_CUDA_CU_API Expr : public Statement { // TODO: Protect based on being in kernel container Expr* withWritePredicate(kir::Predicate* write_predicate); - protected: - // TODO: Protect based on being in kernel container void setPredicate(kir::Predicate* predicate); // TODO: Protect based on being in kernel container void setWritePredicate(kir::Predicate* write_predicate); + protected: + // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 73c410de741d..989c00be81b7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -20,7 +20,7 @@ namespace cuda { namespace { -class ConditionalFromPredicateModifier : public kir::ExprMutator { +class ConditionalFromPredicateModifier : public kir::IrVisitor { public: ConditionalFromPredicateModifier() = delete; @@ -33,10 +33,10 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { ConditionalFromPredicateModifier(const std::vector& exprs) { FUSER_PERF_SCOPE( "GpuLower::Lower::ConditionalFromPredicateModifier::process"); - kir::ExprMutator::handle(exprs); + kir::IrVisitor::handle(exprs); } - using kir::ExprMutator::handle; + using kir::IrVisitor::handle; void handle(Expr* expr) final { if (expr != nullptr && expr->predicate() != nullptr) { @@ -131,7 +131,7 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { } else { // If generateConditional returns null, it means no specific // predicate needs to 
be used. - registerReplace(expr, expr->withWritePredicate(nullptr)); // shallow copy clears predicates + expr->setWritePredicate(nullptr); } } } diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 2a7c04243f4c..991ed5745c87 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -17,7 +17,7 @@ namespace jit { namespace fuser { namespace cuda { -Expr* ShiftPredicateInserter::insert( +void ShiftPredicateInserter::insert( Expr* expr, const std::vector& loops, Bool* thread_pred, @@ -30,7 +30,7 @@ Expr* ShiftPredicateInserter::insert( const bool needs_shift_predicate = gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition()); if (!needs_shift_predicate) { - return expr; + return; } // The conditional branches to create: @@ -57,7 +57,8 @@ Expr* ShiftPredicateInserter::insert( // the expr with shift_pred. Since the expr is not shift, the // padding is safe to omit. if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) { - return expr->withPredicate(shift_pred); + expr->setPredicate(shift_pred); + return; } auto shift_ite = IrBuilder::create(shift_pred); @@ -75,7 +76,7 @@ Expr* ShiftPredicateInserter::insert( // No padding condition is required if this is within unswitch. if (within_unswitch) { - return expr; + return; } // Padding by zero @@ -88,8 +89,6 @@ Expr* ShiftPredicateInserter::insert( bounds_ite->thenBody().push_back(pad_expr); // Insert the else block shift_ite->elseBody().push_back(bounds_ite); - - return expr; } int AxisHaloInfo::width() const { diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h index f12410703d99..0cb3c3ea4457 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.h +++ b/torch/csrc/jit/codegen/cuda/lower_shift.h @@ -225,7 +225,7 @@ class ShiftPredicateInserter { //! the generated predicate. The branch structure is different from //! the usual predicated expression, so the insertion is also done //! here. - static Expr* insert( + static void insert( Expr* expr, const std::vector& loops, Bool* thread_pred, diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 3996eeb5fe55..8ba3909d19e7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -82,11 +82,8 @@ void UnrollPass::handle(Expr* expr) { // When a predicate needs to account for ShiftOp, it is currently // taken care by its own function. if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) { - auto expr_with_predicate = ShiftPredicateInserter::insert( + ShiftPredicateInserter::insert( expr, for_loops_, thread_pred, unswitched_loop_); - if (expr_with_predicate != expr) { - registerReplace(expr, expr_with_predicate); - } return; } @@ -96,7 +93,7 @@ void UnrollPass::handle(Expr* expr) { ? thread_pred_expr : IrBuilder::create( PredicateType::ReductionWrite, expr, thread_pred); - registerReplace(expr, expr->withWritePredicate(write_pred)); + expr->setWritePredicate(write_pred); } // For expr calling a device func with block sync, don't create @@ -106,7 +103,7 @@ void UnrollPass::handle(Expr* expr) { ? 
thread_pred_expr : IrBuilder::create( PredicateType::Inline, expr, thread_pred); - registerReplace(expr, expr->withPredicate(pred)); + expr->setPredicate(pred); return; } From c18e8cb5c6efedd50b4e6b403f5fadebc77b1e9e Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 16:11:16 -0700 Subject: [PATCH 08/15] ir_utils::toString --- torch/csrc/jit/codegen/cuda/ir_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index ad9f67180f92..adfc64fc74ad 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -325,7 +325,7 @@ TORCH_CUDA_CU_API std::vector getViewOps(Fusion*); template std::string toString(const T& nodes) { std::stringstream ss; - for (Statement* stmt : nodes) { + for (const Statement* stmt : nodes) { if (ss.tellp() != 0) { ss << ", "; } From 11856fc0795343f282d2971b7f0fcdf7e09de95d Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 16:33:07 -0700 Subject: [PATCH 09/15] unroll --- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 6 +++ torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 9 ++++ torch/csrc/jit/codegen/cuda/kernel_ir.h | 47 +++++++++++-------- torch/csrc/jit/codegen/cuda/lower_shift.cpp | 11 +++-- torch/csrc/jit/codegen/cuda/lower_shift.h | 2 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 46 +++++++++++++----- torch/csrc/jit/codegen/cuda/lower_unroll.h | 3 +- 7 files changed, 85 insertions(+), 39 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index b594bc7a06d7..2f04a6fab1ac 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -343,6 +343,9 @@ void Expr::setPredicate(kir::Predicate* predicate) { Expr* Expr::withPredicate(kir::Predicate* predicate) { auto result = shallowCopy(); + if (predicate != nullptr) { + predicate = predicate->maybeWithReplacedExpr(this, result); + } result->setPredicate(predicate); return result; } @@ -361,6 +364,9 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) { Expr* Expr::withWritePredicate(kir::Predicate* predicate) { auto result = shallowCopy(); + if (predicate != nullptr) { + predicate = predicate->maybeWithReplacedExpr(this, result); + } result->setWritePredicate(predicate); return result; } diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 8c8cb29dbb78..e5aec63870b0 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -50,6 +50,15 @@ Predicate::Predicate(IrBuilderPasskey passkey, Bool* value) TORCH_INTERNAL_ASSERT(value != nullptr); } +Predicate* Predicate::maybeWithReplacedExpr( + const Expr* expr, + const Expr* new_expr) { + if (expr_ == expr) { + return IrBuilder::create(ptype_, new_expr, thread_pred_); + } + return this; +} + TensorIndex::TensorIndex( IrBuilderPasskey passkey, const TensorView* view, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 030580cf50e6..a59cd53c4159 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -83,6 +83,10 @@ class TORCH_CUDA_CU_API Predicate final : public Val { explicit Predicate(IrBuilderPasskey passkey, Bool* value); + Predicate* maybeWithReplacedExpr( + const Expr* expr, + const Expr* new_expr); // TODO const + PredicateType predicate_type() const { return ptype_; } @@ -199,7 +203,7 @@ 
class TORCH_CUDA_CU_API Allocate final : public Expr { Val* size, bool zero_init = false); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* buffer() const { return buffer_; @@ -253,7 +257,7 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr { public: explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; bool isWarHazardSync() const { return war_sync_; @@ -269,7 +273,7 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { public: explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; //! Returns the remaining number of stages that are not synchronized //! after this op. @@ -289,7 +293,7 @@ class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { public: explicit CpAsyncCommit(IrBuilderPasskey passkey); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; }; // Synchronize all blocks in device, implies cooperative group launch is @@ -301,7 +305,7 @@ class TORCH_CUDA_CU_API GridSync final : public Expr { ParallelTypeBitmap sync_dims, Val* sync_buffer); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; ParallelTypeBitmap syncDims() const { return sync_dims_; @@ -322,7 +326,7 @@ class TORCH_CUDA_CU_API InitMagicZero final : public Expr { public: explicit InitMagicZero(IrBuilderPasskey passkey); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; }; // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero @@ -331,7 +335,7 @@ class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr { public: explicit UpdateMagicZero(IrBuilderPasskey passkey); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; }; // TODO(kir): promote to IR node @@ -432,7 +436,7 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { ForLoop(IrBuilderPasskey passkey, const ForLoop* other); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* index() const { return index_; @@ -528,7 +532,7 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr { public: explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Scope& thenBody() { return then_body_; @@ -575,7 +579,7 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { Val* entrances, bool is_allreduce = false); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Allocate* reduction_buffer() const { return reduction_buffer_; @@ -599,7 +603,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { return thread_predicate_; } - GridReduction* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + GridReduction* withThreadPredicate( + const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); result->thread_predicate_ = thread_predicate; return result; @@ -631,7 +636,7 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { Val* buffer_stride, bool is_allreduce = false); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; const std::vector& reduction_buffers() const { return reduction_buffers_; @@ -663,7 +668,8 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { return thread_predicate_; } - GroupedGridReduction *withThreadPredicate(const 
ParallelTypeBitmap& thread_predicate) { + GroupedGridReduction* withThreadPredicate( + const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); result->thread_predicate_ = thread_predicate; return result; @@ -697,7 +703,7 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { Allocate* broadcast_buffer, Allocate* sync_buffer); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; BroadcastOp* broadcast_op() const { return broadcast_op_; @@ -738,7 +744,7 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { Val* entrance_index, Val* entrances); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; WelfordOp* welford_op() const { return welford_op_; @@ -809,7 +815,7 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { Val* buffer_stride, bool is_allreduce = false); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; const std::array, 3>& reduction_buffers() const { return reduction_buffers_; @@ -837,7 +843,8 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { return thread_predicate_; } - GroupedGridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + GroupedGridWelford* withThreadPredicate( + const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); result->thread_predicate_ = thread_predicate; return result; @@ -875,7 +882,7 @@ class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { IrBuilderPasskey passkey, GroupedGridWelford* grouped_grid_welford); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Expr* gridExpr() const { return grid_expr_; @@ -917,7 +924,7 @@ class TORCH_CUDA_CU_API PairSelect : public Expr { PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -954,7 +961,7 @@ class TORCH_CUDA_CU_API Swizzle2DInt : public Expr { Val* extent_y, Swizzle2DType swizzle_type); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; IntPair* out() const { return out_; diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 991ed5745c87..2a7c04243f4c 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -17,7 +17,7 @@ namespace jit { namespace fuser { namespace cuda { -void ShiftPredicateInserter::insert( +Expr* ShiftPredicateInserter::insert( Expr* expr, const std::vector& loops, Bool* thread_pred, @@ -30,7 +30,7 @@ void ShiftPredicateInserter::insert( const bool needs_shift_predicate = gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition()); if (!needs_shift_predicate) { - return; + return expr; } // The conditional branches to create: @@ -57,8 +57,7 @@ void ShiftPredicateInserter::insert( // the expr with shift_pred. Since the expr is not shift, the // padding is safe to omit. if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) { - expr->setPredicate(shift_pred); - return; + return expr->withPredicate(shift_pred); } auto shift_ite = IrBuilder::create(shift_pred); @@ -76,7 +75,7 @@ void ShiftPredicateInserter::insert( // No padding condition is required if this is within unswitch. 
if (within_unswitch) { - return; + return expr; } // Padding by zero @@ -89,6 +88,8 @@ void ShiftPredicateInserter::insert( bounds_ite->thenBody().push_back(pad_expr); // Insert the else block shift_ite->elseBody().push_back(bounds_ite); + + return expr; } int AxisHaloInfo::width() const { diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h index 0cb3c3ea4457..f12410703d99 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.h +++ b/torch/csrc/jit/codegen/cuda/lower_shift.h @@ -225,7 +225,7 @@ class ShiftPredicateInserter { //! the generated predicate. The branch structure is different from //! the usual predicated expression, so the insertion is also done //! here. - static void insert( + static Expr* insert( Expr* expr, const std::vector& loops, Bool* thread_pred, diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 8ba3909d19e7..78409ba49cc8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -54,6 +54,14 @@ bool isReductionInitExpr(const Expr* expr) { } // namespace +void UnrollPass::registerReplace( + Expr* reference, + Expr* new_expr, + kir::Scope* scope) { + kir::ExprMutator::registerReplace(reference, new_expr, scope); + GpuLower::current()->propagateExprInfo(reference, new_expr); +} + void UnrollPass::handle(Expr* expr) { if (ir_utils::isTvOp(expr)) { // If tv op, predicate it @@ -79,31 +87,41 @@ void UnrollPass::handle(Expr* expr) { non_trivial_pred_found_ = true; + Expr* expr_with_predicate = expr; + // When a predicate needs to account for ShiftOp, it is currently // taken care by its own function. if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) { - ShiftPredicateInserter::insert( - expr, for_loops_, thread_pred, unswitched_loop_); + expr_with_predicate = ShiftPredicateInserter::insert( + expr_with_predicate, for_loops_, thread_pred, unswitched_loop_); + if (expr_with_predicate != expr) { + registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); + } return; } // Reduction may need a separate predicate for writes. - if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) { + if (!isReductionInitExpr(expr_with_predicate) && + out_tv->domain()->hasReduction()) { const auto write_pred = unswitched_loop_ ? thread_pred_expr : IrBuilder::create( - PredicateType::ReductionWrite, expr, thread_pred); - expr->setWritePredicate(write_pred); + PredicateType::ReductionWrite, + expr_with_predicate, + thread_pred); + expr_with_predicate = expr_with_predicate->withWritePredicate(write_pred); } // For expr calling a device func with block sync, don't create // if-then-else but pass the predicate to the device func - if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { + if (lower_utils::hasBlockSync( + expr_with_predicate, GpuLower::current()->threadPredMap())) { const auto pred = unswitched_loop_ ? thread_pred_expr : IrBuilder::create( - PredicateType::Inline, expr, thread_pred); - expr->setPredicate(pred); + PredicateType::Inline, expr_with_predicate, thread_pred); + expr_with_predicate = expr_with_predicate->withPredicate(pred); + registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); return; } @@ -119,9 +137,10 @@ void UnrollPass::handle(Expr* expr) { } if (pred == nullptr) { - pred = unswitched_loop_ ? thread_pred_expr - : IrBuilder::create( - PredicateType::Inline, expr, thread_pred); + pred = unswitched_loop_ + ? 
thread_pred_expr + : IrBuilder::create( + PredicateType::Inline, expr_with_predicate, thread_pred); } // If we need a predicate, put expr inside an if then else @@ -135,7 +154,10 @@ void UnrollPass::handle(Expr* expr) { kir::ExprMutator::registerReplace( expr, inline_ite, &for_loops_.back()->body()); } - inline_ite->thenBody().push_back(expr); + if (expr != expr_with_predicate) { + GpuLower::current()->propagateExprInfo(expr, expr_with_predicate); + } + inline_ite->thenBody().push_back(expr_with_predicate); } else if (auto for_loop = dynamic_cast(expr)) { handle(for_loop); } diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h index 14725c405b77..14280def7643 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.h +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h @@ -62,6 +62,8 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { static bool canOmitElseClause(kir::ForLoop* fl); private: + void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope); + // Generate the for Expr replacement map UnrollPass(const std::vector& exprs); @@ -75,7 +77,6 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { void handle(Expr* expr) final; - private: // We will track which loops in the incoming IR will be replaced and by what std::unordered_map expr_replacement_map_; From ddb03b90b1026678fd267408c7ec9835174b3ba3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 17:01:36 -0700 Subject: [PATCH 10/15] lower predicate --- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 5 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 46 +++++++++---------- torch/csrc/jit/codegen/cuda/kernel_ir.h | 2 +- .../csrc/jit/codegen/cuda/lower_predicate.cpp | 18 ++++---- 4 files changed, 35 insertions(+), 36 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index ad3e010f5362..781cf2145813 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -426,7 +426,7 @@ class TORCH_CUDA_CU_API Expr : public Statement { Expr(const Expr* src, IrCloner* ir_cloner); - virtual Expr *shallowCopy() const = 0; + virtual Expr* shallowCopy() const = 0; c10::optional getExprType() const override { return etype_; @@ -477,14 +477,13 @@ class TORCH_CUDA_CU_API Expr : public Statement { // TODO: Protect based on being in kernel container Expr* withWritePredicate(kir::Predicate* write_predicate); + protected: // TODO: Protect based on being in kernel container void setPredicate(kir::Predicate* predicate); // TODO: Protect based on being in kernel container void setWritePredicate(kir::Predicate* write_predicate); - protected: - // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 658860d0ea17..74c055c3faac 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -36,7 +36,7 @@ class TORCH_CUDA_CU_API FullOp : public Expr { FullOp(const FullOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; bool sameAs(const Statement* other) const override; @@ -66,7 +66,7 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { ARangeOp(const ARangeOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; bool sameAs(const Statement* other) const override; 
@@ -131,7 +131,7 @@ class TORCH_CUDA_CU_API EyeOp : public Expr { EyeOp(const EyeOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; bool sameAs(const Statement* other) const override; @@ -178,7 +178,7 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -209,7 +209,7 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr { BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -249,7 +249,7 @@ class TORCH_CUDA_CU_API RNGOp : public Expr { RNGOp(const RNGOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; RNGOpType getRNGOpType() const { return rng_op_type_; @@ -310,7 +310,7 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -360,7 +360,7 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -410,7 +410,7 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. @@ -598,7 +598,7 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { WelfordOp(const WelfordOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return output().avg(); @@ -695,7 +695,7 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. 
As horizontal grouping is not supported, @@ -820,7 +820,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { MmaOp(const MmaOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -880,7 +880,7 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr { TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; TensorView* out() const { return out_; @@ -912,7 +912,7 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr { ExpandOp(const ExpandOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; TensorView* out() const { return out_; @@ -944,7 +944,7 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr { TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -989,7 +989,7 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -1040,7 +1040,7 @@ class TORCH_CUDA_CU_API GatherOp : public Expr { GatherOp(const GatherOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -1088,7 +1088,7 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr { ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -1123,7 +1123,7 @@ class TORCH_CUDA_CU_API ViewOp : public Expr { ViewOp(const ViewOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; TensorView* out() const { return out_; @@ -1150,7 +1150,7 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr { LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; Val* out() const { return out_; @@ -1731,7 +1731,7 @@ class TORCH_CUDA_CU_API Split : public Expr { Split(const Split* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; IterDomain* outer() const { return outer_; @@ -1793,7 +1793,7 @@ class TORCH_CUDA_CU_API Merge : public Expr { Merge(const Merge* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; IterDomain* out() const { return out_; @@ -1827,7 +1827,7 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner); - Expr *shallowCopy() const override; + Expr* shallowCopy() const override; IterDomain* outX() const { return out_x_; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index a59cd53c4159..a82b6df98eda 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -85,7 +85,7 @@ class TORCH_CUDA_CU_API Predicate final : public Val { Predicate* maybeWithReplacedExpr( const Expr* expr, - const Expr* new_expr); // TODO const + const Expr* new_expr); // TODO make this const PredicateType predicate_type() const { return ptype_; diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 989c00be81b7..4c7a73f5bfdf 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ 
b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -20,7 +20,7 @@ namespace cuda { namespace { -class ConditionalFromPredicateModifier : public kir::IrVisitor { +class ConditionalFromPredicateModifier : public kir::ExprMutator { public: ConditionalFromPredicateModifier() = delete; @@ -32,11 +32,11 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { private: ConditionalFromPredicateModifier(const std::vector& exprs) { FUSER_PERF_SCOPE( - "GpuLower::Lower::ConditionalFromPredicateModifier::process"); - kir::IrVisitor::handle(exprs); + "ConditionalFromPredicateModifier::ConditionalFromPredicateModifier"); + traverseAndInsert(exprs); } - using kir::IrVisitor::handle; + using kir::ExprMutator::handle; void handle(Expr* expr) final { if (expr != nullptr && expr->predicate() != nullptr) { @@ -72,7 +72,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { TORCH_INTERNAL_ASSERT(conditional != nullptr); expr->predicate()->setValue(conditional); TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr); - setWritePredicate(expr, conditional); + setWritePredicate(expr); } // Note: [Predicate Inversion for CpAsync] @@ -101,7 +101,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { invertPredicateForGmemToSharedMemInitialize(expr); } - kir::IrVisitor::handle(expr); + kir::ExprMutator::handle(expr); } // Invert the predicate of given expr. @@ -123,7 +123,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { ir_utils::isCpAsyncInit(maybe_init.value()); } - void setWritePredicate(Expr* expr, Bool* read_cond) { + void setWritePredicate(Expr* expr) { if (expr->writePredicate() != nullptr) { auto write_cond = generateConditional(expr->writePredicate()); if (write_cond) { @@ -131,7 +131,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { } else { // If generateConditional returns null, it means no specific // predicate needs to be used. - expr->setWritePredicate(nullptr); + registerReplace(expr, expr->withWritePredicate(nullptr)); } } } @@ -150,7 +150,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { ite->predicate()->setValue(conditional); TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr); } - kir::IrVisitor::handle(ite); + kir::ExprMutator::handle(ite); } // Generate conditional according to PredicateType From c13489dca830ad350fff427de6ec5cd4ed0118ab Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 3 Oct 2022 18:04:32 -0700 Subject: [PATCH 11/15] maybeWithReplacedExpr const --- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 15 ++++++++++----- torch/csrc/jit/codegen/cuda/kernel_ir.h | 7 +++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index e5aec63870b0..33c45fc45652 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -52,11 +52,16 @@ Predicate::Predicate(IrBuilderPasskey passkey, Bool* value) Predicate* Predicate::maybeWithReplacedExpr( const Expr* expr, - const Expr* new_expr) { - if (expr_ == expr) { - return IrBuilder::create(ptype_, new_expr, thread_pred_); - } - return this; + const Expr* new_expr) const { + // Just create an instance with dummy data, real data will be filled below + auto result = IrBuilder::create(PredicateType::Vectorize); + + result->ptype_ = ptype_; + result->expr_ = (expr_ == expr ? 
new_expr : expr_); + result->thread_pred_ = thread_pred_; + result->unrolled_loop_ = unrolled_loop_; + result->value_ = value_; + return result; } TensorIndex::TensorIndex( diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index a82b6df98eda..c58fc9a6d1cb 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -83,9 +83,8 @@ class TORCH_CUDA_CU_API Predicate final : public Val { explicit Predicate(IrBuilderPasskey passkey, Bool* value); - Predicate* maybeWithReplacedExpr( - const Expr* expr, - const Expr* new_expr); // TODO make this const + Predicate* maybeWithReplacedExpr(const Expr* expr, const Expr* new_expr) + const; PredicateType predicate_type() const { return ptype_; } @@ -98,7 +97,7 @@ class TORCH_CUDA_CU_API Predicate final : public Val { return expr_; } - Bool* thread_pred() { + Bool* thread_pred() const { TORCH_INTERNAL_ASSERT( ptype_ == PredicateType::Inline || ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift || From a799398385e4246e1d52e97787a90bf7dce44382 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 20:52:06 -0700 Subject: [PATCH 12/15] comments --- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 781cf2145813..706dfb5ea7e3 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -426,6 +426,8 @@ class TORCH_CUDA_CU_API Expr : public Statement { Expr(const Expr* src, IrCloner* ir_cloner); + // Creates a new instance of the expression with all its fields copied. + // Note that unlike IrCloner, this function only does a shallow copy. virtual Expr* shallowCopy() const = 0; c10::optional getExprType() const override { @@ -468,12 +470,15 @@ class TORCH_CUDA_CU_API Expr : public Statement { // TODO: Protect based on being in kernel container kir::Predicate* predicate() const; + // Creates a shallow copy of the expression with the given predicate attached. // TODO: Protect based on being in kernel container Expr* withPredicate(kir::Predicate* predicate); // TODO: Protect based on being in kernel container kir::Predicate* writePredicate() const; + // Creates a shallow copy of the expression with the given write-predicate + // attached.
// TODO: Protect based on being in kernel container Expr* withWritePredicate(kir::Predicate* write_predicate); From 5eac017a5b70d172890363da59e99bc531ba3758 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 22:57:03 -0700 Subject: [PATCH 13/15] Remove maybeWithReplacedExpr --- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 6 ------ torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 14 ------------- torch/csrc/jit/codegen/cuda/kernel_ir.h | 3 --- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 21 +++++++------------ 4 files changed, 8 insertions(+), 36 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 2f04a6fab1ac..b594bc7a06d7 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -343,9 +343,6 @@ void Expr::setPredicate(kir::Predicate* predicate) { Expr* Expr::withPredicate(kir::Predicate* predicate) { auto result = shallowCopy(); - if (predicate != nullptr) { - predicate = predicate->maybeWithReplacedExpr(this, result); - } result->setPredicate(predicate); return result; } @@ -364,9 +361,6 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) { Expr* Expr::withWritePredicate(kir::Predicate* predicate) { auto result = shallowCopy(); - if (predicate != nullptr) { - predicate = predicate->maybeWithReplacedExpr(this, result); - } result->setWritePredicate(predicate); return result; } diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 33c45fc45652..8c8cb29dbb78 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -50,20 +50,6 @@ Predicate::Predicate(IrBuilderPasskey passkey, Bool* value) TORCH_INTERNAL_ASSERT(value != nullptr); } -Predicate* Predicate::maybeWithReplacedExpr( - const Expr* expr, - const Expr* new_expr) const { - // Just create an instance with dummy data, real data will be filled below - auto result = IrBuilder::create(PredicateType::Vectorize); - - result->ptype_ = ptype_; - result->expr_ = (expr_ == expr ? new_expr : expr_); - result->thread_pred_ = thread_pred_; - result->unrolled_loop_ = unrolled_loop_; - result->value_ = value_; - return result; -} - TensorIndex::TensorIndex( IrBuilderPasskey passkey, const TensorView* view, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index c58fc9a6d1cb..cd44e8d8e21b 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -83,9 +83,6 @@ class TORCH_CUDA_CU_API Predicate final : public Val { explicit Predicate(IrBuilderPasskey passkey, Bool* value); - Predicate* maybeWithReplacedExpr(const Expr* expr, const Expr* new_expr) - const; - PredicateType predicate_type() const { return ptype_; } diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 78409ba49cc8..97a544da9f64 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -93,7 +93,7 @@ void UnrollPass::handle(Expr* expr) { // taken care by its own function. 
if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) { expr_with_predicate = ShiftPredicateInserter::insert( - expr_with_predicate, for_loops_, thread_pred, unswitched_loop_); + expr, for_loops_, thread_pred, unswitched_loop_); if (expr_with_predicate != expr) { registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); } @@ -101,25 +101,21 @@ void UnrollPass::handle(Expr* expr) { } // Reduction may need a separate predicate for writes. - if (!isReductionInitExpr(expr_with_predicate) && - out_tv->domain()->hasReduction()) { + if (!isReductionInitExpr(expr) && out_tv->domain()->hasReduction()) { const auto write_pred = unswitched_loop_ ? thread_pred_expr : IrBuilder::create( - PredicateType::ReductionWrite, - expr_with_predicate, - thread_pred); + PredicateType::ReductionWrite, expr, thread_pred); expr_with_predicate = expr_with_predicate->withWritePredicate(write_pred); } // For expr calling a device func with block sync, don't create // if-then-else but pass the predicate to the device func - if (lower_utils::hasBlockSync( - expr_with_predicate, GpuLower::current()->threadPredMap())) { + if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { const auto pred = unswitched_loop_ ? thread_pred_expr : IrBuilder::create( - PredicateType::Inline, expr_with_predicate, thread_pred); + PredicateType::Inline, expr, thread_pred); expr_with_predicate = expr_with_predicate->withPredicate(pred); registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); return; @@ -137,10 +133,9 @@ void UnrollPass::handle(Expr* expr) { } if (pred == nullptr) { - pred = unswitched_loop_ - ? thread_pred_expr - : IrBuilder::create( - PredicateType::Inline, expr_with_predicate, thread_pred); + pred = unswitched_loop_ ? thread_pred_expr + : IrBuilder::create( + PredicateType::Inline, expr, thread_pred); } // If we need a predicate, put expr inside an if then else From 064aec347b32f79ac48bc9586e17e09310acb4d9 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 4 Oct 2022 23:18:21 -0700 Subject: [PATCH 14/15] private: --- torch/csrc/jit/codegen/cuda/lower_unroll.h | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h index 14280def7643..786e45115ba6 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.h +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h @@ -77,6 +77,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { void handle(Expr* expr) final; + private: // We will track which loops in the incoming IR will be replaced and by what std::unordered_map expr_replacement_map_; From d28dbafad2172bcd370ca9af055c4cc36b69fc63 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 5 Oct 2022 09:45:29 -0700 Subject: [PATCH 15/15] copyPredicatesFrom --- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 7 ++ torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 2 + torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 115 ++++-------------- torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 45 +++---- 4 files changed, 47 insertions(+), 122 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index b594bc7a06d7..ff00f659da63 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -365,6 +365,13 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) { return result; } +void Expr::copyPredicatesFrom(const Expr* expr) { + if (container()->isA()) { + predicate_ = expr->predicate_; + 
write_predicate_ = expr->write_predicate_;
+  }
+}
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index 706dfb5ea7e3..dadabe167ebf 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -489,6 +489,8 @@ class TORCH_CUDA_CU_API Expr : public Statement {
   // TODO: Protect based on being in kernel container
   void setWritePredicate(kir::Predicate* write_predicate);
 
+  void copyPredicatesFrom(const Expr* expr);
+
   // TODO: Add Fusion passkey
   void addInput(Val* input) {
     TORCH_INTERNAL_ASSERT(input != nullptr);
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 2dec991dd9a5..a2e030e68e85 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -202,10 +202,7 @@ FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner)
 
 Expr* FullOp::shallowCopy() const {
   auto result = IrBuilder::create<FullOp>(output(0), fill_value_, dtype_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -254,10 +251,7 @@ ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner)
 Expr* ARangeOp::shallowCopy() const {
   auto result = IrBuilder::create<ARangeOp>(
       output(0), start_, end_, step_, dtype_, linear_index_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -319,10 +313,7 @@ EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner)
 
 Expr* EyeOp::shallowCopy() const {
   auto result = IrBuilder::create<EyeOp>(output(0), dtype_, index1_, index2_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -374,10 +365,7 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
 
 Expr* UnaryOp::shallowCopy() const {
   auto result = IrBuilder::create<UnaryOp>(unary_op_type_, out_, in_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -420,10 +408,7 @@ BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner)
 
 Expr* BinaryOp::shallowCopy() const {
   auto result = IrBuilder::create<BinaryOp>(binary_op_type_, out_, lhs_, rhs_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -471,10 +456,7 @@ TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner)
 Expr* TernaryOp::shallowCopy() const {
   auto result =
       IrBuilder::create<TernaryOp>(ternary_op_type_, out_, in1_, in2_, in3_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -531,10 +513,7 @@ RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner)
 Expr* RNGOp::shallowCopy() const {
   auto result = IrBuilder::create<RNGOp>(
       rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -651,10 +630,7 @@ BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
 
 Expr* BroadcastOp::shallowCopy() const {
   auto result = IrBuilder::create<BroadcastOp>(out_, in_, is_broadcast_dims_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -723,10 +699,7 @@ ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
 Expr* ReductionOp::shallowCopy() const {
   auto result = IrBuilder::create<ReductionOp>(
       reduction_op_type_, init_, out_, in_, is_allreduce_, etype());
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -782,10 +755,7 @@ Expr* GroupedReductionOp::shallowCopy() const {
       inputs(),
       is_allreduce_,
       etype());
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -966,10 +936,7 @@ WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner)
 Expr* WelfordOp::shallowCopy() const {
   auto result =
       IrBuilder::create<WelfordOp>(output_, input_, init_, is_allreduce_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1125,10 +1092,7 @@ GroupedWelfordOp::GroupedWelfordOp(
 Expr* GroupedWelfordOp::shallowCopy() const {
   auto result = IrBuilder::create<GroupedWelfordOp>(
       output_vals_, input_vals_, init_vals_, is_allreduce_, etype());
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1229,10 +1193,7 @@ MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner)
 Expr* MmaOp::shallowCopy() const {
   auto result = IrBuilder::create<MmaOp>(out_, in_a_, in_b_, init_);
   result->options_ = options_;
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1294,10 +1255,7 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner)
 
 Expr* TransposeOp::shallowCopy() const {
   auto result = IrBuilder::create<TransposeOp>(out_, in_, new2old_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1342,10 +1300,7 @@ ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner)
 
 Expr* ExpandOp::shallowCopy() const {
   auto result = IrBuilder::create<ExpandOp>(out_, in_, expanded_extents_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1398,10 +1353,7 @@ ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner)
 
 Expr* ShiftOp::shallowCopy() const {
   auto result = IrBuilder::create<ShiftOp>(out_, in_, offsets_, pad_width_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1470,10 +1422,7 @@ GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner)
 Expr* GatherOp::shallowCopy() const {
   auto result =
       IrBuilder::create<GatherOp>(out_, in_, window_shape_, pad_width_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1525,10 +1474,7 @@ ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner)
 
 Expr* ViewAsScalar::shallowCopy() const {
   auto result = IrBuilder::create<ViewAsScalar>(out_, in_, vector_id_, index_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1545,10 +1491,7 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner)
 
 Expr* ViewOp::shallowCopy() const {
   auto result = IrBuilder::create<ViewOp>(out_, in_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -1573,10 +1516,7 @@ LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner)
 
 Expr* LoadStoreOp::shallowCopy() const {
   auto result = IrBuilder::create<LoadStoreOp>(load_store_type_, out_, in_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -2692,10 +2632,7 @@ Split::Split(const Split* src, IrCloner* ir_cloner)
 Expr* Split::shallowCopy() const {
   auto result = IrBuilder::create<Split>(
       outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -2746,10 +2683,7 @@ Merge::Merge(const Merge* src, IrCloner* ir_cloner)
 
 Expr* Merge::shallowCopy() const {
   auto result = IrBuilder::create<Merge>(out_, outer_, inner_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -2787,10 +2721,7 @@ Swizzle2D::Swizzle2D(
 Expr* Swizzle2D::shallowCopy() const {
   auto result = IrBuilder::create<Swizzle2D>(
       out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_);
-  if (container()->isA<kir::Kernel>()) {
-    result->setPredicate(predicate());
-    result->setWritePredicate(writePredicate());
-  }
+  result->copyPredicatesFrom(this);
   return result;
 }
 
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
index 8c8cb29dbb78..7e69f0307a7a 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -96,8 +96,7 @@ BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync)
 
 Expr* BlockSync::shallowCopy() const {
   auto result = IrBuilder::create<BlockSync>(war_sync_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -111,8 +110,7 @@ GridSync::GridSync(
 
 Expr* GridSync::shallowCopy() const {
   auto result = IrBuilder::create<GridSync>(sync_dims_, sync_buffer_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -125,8 +123,7 @@ CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
 
 Expr* CpAsyncWait::shallowCopy() const {
   auto result = IrBuilder::create<CpAsyncWait>(keep_stages_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -139,8 +136,7 @@ CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey)
 
 Expr* CpAsyncCommit::shallowCopy() const {
   auto result = IrBuilder::create<CpAsyncCommit>();
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -153,8 +149,7 @@ InitMagicZero::InitMagicZero(IrBuilderPasskey passkey)
 
 Expr* InitMagicZero::shallowCopy() const {
   auto result = IrBuilder::create<InitMagicZero>();
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -167,8 +162,7 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey)
 
 Expr* UpdateMagicZero::shallowCopy() const {
   auto result = IrBuilder::create<UpdateMagicZero>();
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -200,8 +194,7 @@ PairSelect::PairSelect(
 
 Expr* PairSelect::shallowCopy() const {
   auto result = IrBuilder::create<PairSelect>(out_, in_, selection_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -233,8 +226,7 @@ Swizzle2DInt::Swizzle2DInt(
 Expr* Swizzle2DInt::shallowCopy() const {
   auto result = IrBuilder::create<Swizzle2DInt>(
       out_, in_x_, in_y_, extent_x_, extent_y_, swizzle_type_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -385,8 +377,7 @@ Expr* ForLoop::shallowCopy() const {
       unroll_required_,
       double_buffer_loop_stage_);
   result->body_ = body_;
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -580,8 +571,7 @@ Allocate::Allocate(
 Expr* Allocate::shallowCopy() const {
   auto result =
      IrBuilder::create<Allocate>(buffer_, memory_type_, shape_, zero_init_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -624,8 +614,7 @@ Expr* GridReduction::shallowCopy() const {
       entrance_index_,
       entrances_,
       isAllreduce());
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   result->thread_predicate_ = thread_predicate_;
   return result;
 }
@@ -672,8 +661,7 @@ Expr* GroupedGridReduction::shallowCopy() const {
       entrances_,
       buffer_stride_,
       isAllreduce());
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   result->thread_predicate_ = thread_predicate_;
   return result;
 }
@@ -695,8 +683,7 @@ GridBroadcast::GridBroadcast(
 Expr* GridBroadcast::shallowCopy() const {
   auto result = IrBuilder::create<GridBroadcast>(
       broadcast_op_, broadcast_buffer_, sync_buffer_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   return result;
 }
 
@@ -731,8 +718,7 @@ Expr* GridWelford::shallowCopy() const {
       sync_buffer_,
       entrance_index_,
       entrances_);
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   result->thread_predicate_ = thread_predicate_;
   return result;
 }
@@ -776,8 +762,7 @@ Expr* GroupedGridWelford::shallowCopy() const {
       entrances_,
       buffer_stride_,
       isAllreduce());
-  result->setPredicate(predicate());
-  result->setWritePredicate(writePredicate());
+  result->copyPredicatesFrom(this);
   result->thread_predicate_ = thread_predicate_;
   return result;
 }
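
For reference, the pattern every hunk above converges on: each shallowCopy() rebuilds the node with IrBuilder::create<T>(...) and then delegates predicate propagation to Expr::copyPredicatesFrom(), which copies predicate_ and write_predicate_ only when the expression lives in a kir::Kernel container. A minimal sketch of how a caller might combine this with the withPredicate()-style API declared earlier in the series (illustrative only, not part of the patch; withNewReadPredicate and new_pred are made-up names, and the actual withPredicate() implementation is not shown in these hunks):

    // Sketch only: produce a copy that differs from `expr` solely in its read
    // predicate, leaving the original expression untouched.
    Expr* withNewReadPredicate(Expr* expr, kir::Predicate* new_pred) {
      Expr* copy = expr->shallowCopy();  // copyPredicatesFrom() carries over both
                                         // predicates when inside a kir::Kernel
      copy->setPredicate(new_pred);      // then override the read predicate
      return copy;
    }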