diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
index b29a8bc417cd..ff00f659da63 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -341,6 +341,12 @@ void Expr::setPredicate(kir::Predicate* predicate) {
   predicate_ = predicate;
 }
 
+Expr* Expr::withPredicate(kir::Predicate* predicate) {
+  auto result = shallowCopy();
+  result->setPredicate(predicate);
+  return result;
+}
+
 kir::Predicate* Expr::writePredicate() const {
   TORCH_INTERNAL_ASSERT(
       container()->isA<kir::Kernel>(), "Function invalid for fusion.");
@@ -353,6 +359,19 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) {
   write_predicate_ = write_predicate;
 }
 
+Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
+  auto result = shallowCopy();
+  result->setWritePredicate(predicate);
+  return result;
+}
+
+void Expr::copyPredicatesFrom(const Expr* expr) {
+  if (container()->isA<kir::Kernel>()) {
+    predicate_ = expr->predicate_;
+    write_predicate_ = expr->write_predicate_;
+  }
+}
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index 7d5ebad25282..dadabe167ebf 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -426,6 +426,10 @@ class TORCH_CUDA_CU_API Expr : public Statement {
 
   Expr(const Expr* src, IrCloner* ir_cloner);
 
+  // Creates a new instance of the expression with all its fields copied.
+  // Note that unlike IrCloner, this function only does a shallow copy.
+  virtual Expr* shallowCopy() const = 0;
+
   c10::optional<ExprType> getExprType() const override {
     return etype_;
   }
@@ -466,16 +470,27 @@ class TORCH_CUDA_CU_API Expr : public Statement {
   // TODO: Protect based on being in kernel container
   kir::Predicate* predicate() const;
 
+  // Creates a shallow copy of the expression with the given predicate
+  // attached.
   // TODO: Protect based on being in kernel container
-  void setPredicate(kir::Predicate* predicate);
+  Expr* withPredicate(kir::Predicate* predicate);
 
   // TODO: Protect based on being in kernel container
   kir::Predicate* writePredicate() const;
 
+  // Creates a shallow copy of the expression with the given write-predicate
+  // attached.
   // TODO: Protect based on being in kernel container
-  void setWritePredicate(kir::Predicate* write_predicate);
+  Expr* withWritePredicate(kir::Predicate* write_predicate);
 
  protected:
+  // TODO: Protect based on being in kernel container
+  void setPredicate(kir::Predicate* predicate);
+
+  // TODO: Protect based on being in kernel container
+  void setWritePredicate(kir::Predicate* write_predicate);
+
+  void copyPredicatesFrom(const Expr* expr);
+
   // TODO: Add Fusion passkey
   void addInput(Val* input) {
     TORCH_INTERNAL_ASSERT(input != nullptr);
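The two files above form the core of the change: predicates are no longer mutated in place. setPredicate()/setWritePredicate() become protected, and callers instead receive a predicated shallow copy via withPredicate()/withWritePredicate(). A minimal caller-side sketch, using only the API introduced above (the surrounding variables `expr` and `pred` are illustrative, not part of this patch):

    // Old style (no longer possible from outside the class hierarchy):
    //   expr->setPredicate(pred);
    // New style: build a shallow copy carrying the predicate, then use the
    // copy wherever the original expression was referenced.
    Expr* predicated = expr->withPredicate(pred);
    // `predicated` shares all inputs/outputs with `expr`; only the
    // predicate field differs.

Note that copyPredicatesFrom() is deliberately a no-op outside a kir::Kernel container, since fusion-level expressions carry no predicates.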
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index aa8793366a32..74c055c3faac 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -36,6 +36,8 @@ class TORCH_CUDA_CU_API FullOp : public Expr {
 
   FullOp(const FullOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   bool sameAs(const Statement* other) const override;
 
   DataType dtype() const {
@@ -64,6 +66,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr {
 
   ARangeOp(const ARangeOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   bool sameAs(const Statement* other) const override;
 
   DataType dtype() const {
@@ -127,6 +131,8 @@ class TORCH_CUDA_CU_API EyeOp : public Expr {
 
   EyeOp(const EyeOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   bool sameAs(const Statement* other) const override;
 
   DataType dtype() const {
@@ -172,6 +178,8 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {
 
   UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -201,6 +209,8 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr {
 
   BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -239,6 +249,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
 
   RNGOp(const RNGOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   RNGOpType getRNGOpType() const {
     return rng_op_type_;
   }
@@ -298,6 +310,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr {
 
   BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -346,6 +360,8 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr {
 
   ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -394,6 +410,8 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {
 
   GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   //! Number of expressions grouped horizontally. It does not reflect
   //! iteration grouping.
   size_t numExprs() const {
@@ -580,6 +598,8 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {
 
   WelfordOp(const WelfordOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return output().avg();
   }
@@ -675,6 +695,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr {
 
   GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   //! Number of expressions grouped horizontally. It does not reflect
   //! iteration grouping. As horizontal grouping is not supported,
   //! this always returns 1.
@@ -798,6 +820,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
 
   MmaOp(const MmaOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -856,6 +880,8 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr {
 
   TransposeOp(const TransposeOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   TensorView* out() const {
     return out_;
   }
@@ -886,6 +912,8 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr {
 
   ExpandOp(const ExpandOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   TensorView* out() const {
     return out_;
   }
@@ -916,6 +944,8 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr {
 
   TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -959,6 +989,8 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr {
 
   ShiftOp(const ShiftOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -1008,6 +1040,8 @@ class TORCH_CUDA_CU_API GatherOp : public Expr {
 
   GatherOp(const GatherOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -1054,6 +1088,8 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr {
 
   ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -1087,6 +1123,8 @@ class TORCH_CUDA_CU_API ViewOp : public Expr {
 
   ViewOp(const ViewOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   TensorView* out() const {
     return out_;
   }
@@ -1112,6 +1150,8 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr {
 
   LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -1691,6 +1731,8 @@ class TORCH_CUDA_CU_API Split : public Expr {
 
   Split(const Split* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   IterDomain* outer() const {
     return outer_;
   }
@@ -1751,6 +1793,8 @@ class TORCH_CUDA_CU_API Merge : public Expr {
 
   Merge(const Merge* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   IterDomain* out() const {
     return out_;
   }
@@ -1783,6 +1827,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
 
   Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);
 
+  Expr* shallowCopy() const override;
+
   IterDomain* outX() const {
     return out_x_;
   }
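Every concrete Expr subclass now declares the same shallowCopy() override, and the definitions in ir_nodes.cpp below all follow one recipe: rebuild the node from its current fields through IrBuilder, then carry over any predicates. A sketch of that recipe for a hypothetical op (HypotheticalOp and its fields are illustrative, not part of this patch):

    Expr* HypotheticalOp::shallowCopy() const {
      // Re-create the node from its fields. Unlike IrCloner, inputs and
      // outputs are shared pointers, not recursively cloned.
      auto result = IrBuilder::create<HypotheticalOp>(out_, in_);
      // Preserve predicate()/writePredicate() when inside a kernel container.
      result->copyPredicatesFrom(this);
      return result;
    }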
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 3319bf28a18a..1c3dae9a72d9 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -200,6 +200,12 @@ FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner)
       dtype_(src->dtype()),
       fill_value_(ir_cloner->clone(src->fill_value_)) {}
 
+Expr* FullOp::shallowCopy() const {
+  auto result = IrBuilder::create<FullOp>(output(0), fill_value_, dtype_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool FullOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -242,6 +248,43 @@ ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner)
       step_(ir_cloner->clone(src->step_)),
       linear_index_(ir_cloner->clone(src->linear_index_)) {}
 
+Expr* ARangeOp::shallowCopy() const {
+  auto result = IrBuilder::create<ARangeOp>(
+      output(0), start_, end_, step_, dtype_, linear_index_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
+bool ARangeOp::sameAs(const Statement* other) const {
+  if (this == other) {
+    return true;
+  }
+  if (!other->isA<ARangeOp>()) {
+    return false;
+  }
+  const auto other_op = other->as<ARangeOp>();
+  if (dtype_ != other_op->dtype_) {
+    return false;
+  }
+  if (!start_->sameAs(other_op->start_)) {
+    return false;
+  }
+  if (!end_->sameAs(other_op->end_)) {
+    return false;
+  }
+  if (!step_->sameAs(other_op->step_)) {
+    return false;
+  }
+  if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) {
+    return false;
+  }
+  if ((linear_index_ != nullptr) &&
+      !linear_index_->sameAs(other_op->linear_index_)) {
+    return false;
+  }
+  return Expr::sameAs(other);
+}
+
 EyeOp::EyeOp(
     IrBuilderPasskey passkey,
     Val* out,
@@ -268,6 +311,12 @@ EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner)
       index1_(ir_cloner->clone(src->index1_)),
       index2_(ir_cloner->clone(src->index2_)) {}
 
+Expr* EyeOp::shallowCopy() const {
+  auto result = IrBuilder::create<EyeOp>(output(0), dtype_, index1_, index2_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool EyeOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -294,36 +343,6 @@ bool EyeOp::sameAs(const Statement* other) const {
   return Expr::sameAs(other);
 }
 
-bool ARangeOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<ARangeOp>()) {
-    return false;
-  }
-  const auto other_op = other->as<ARangeOp>();
-  if (dtype_ != other_op->dtype_) {
-    return false;
-  }
-  if (!start_->sameAs(other_op->start_)) {
-    return false;
-  }
-  if (!end_->sameAs(other_op->end_)) {
-    return false;
-  }
-  if (!step_->sameAs(other_op->step_)) {
-    return false;
-  }
-  if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) {
-    return false;
-  }
-  if ((linear_index_ != nullptr) &&
-      !linear_index_->sameAs(other_op->linear_index_)) {
-    return false;
-  }
-  return Expr::sameAs(other);
-}
-
 UnaryOp::UnaryOp(
     IrBuilderPasskey passkey,
     UnaryOpType type,
@@ -344,6 +363,12 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
       out_(ir_cloner->clone(src->out_)),
       in_(ir_cloner->clone(src->in_)) {}
 
+Expr* UnaryOp::shallowCopy() const {
+  auto result = IrBuilder::create<UnaryOp>(unary_op_type_, out_, in_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool UnaryOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -381,6 +406,12 @@ BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner)
       lhs_(ir_cloner->clone(src->lhs_)),
       rhs_(ir_cloner->clone(src->rhs_)) {}
 
+Expr* BinaryOp::shallowCopy() const {
+  auto result = IrBuilder::create<BinaryOp>(binary_op_type_, out_, lhs_, rhs_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool BinaryOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -422,6 +453,13 @@ TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner)
       in2_(ir_cloner->clone(src->in2_)),
       in3_(ir_cloner->clone(src->in3_)) {}
 
+Expr* TernaryOp::shallowCopy() const {
+  auto result =
+      IrBuilder::create<TernaryOp>(ternary_op_type_, out_, in1_, in2_, in3_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool TernaryOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -472,6 +510,13 @@ RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner)
       rng_offset_(src->rng_offset_),
       philox_index_(ir_cloner->clone(src->philox_index_)) {}
 
+Expr* RNGOp::shallowCopy() const {
+  auto result = IrBuilder::create<RNGOp>(
+      rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool RNGOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -583,6 +628,12 @@ BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
       in_(ir_cloner->clone(src->in_)),
       is_broadcast_dims_(src->is_broadcast_dims_) {}
 
+Expr* BroadcastOp::shallowCopy() const {
+  auto result = IrBuilder::create<BroadcastOp>(out_, in_, is_broadcast_dims_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool BroadcastOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -637,6 +688,36 @@ ReductionOp::ReductionOp(
   addInput(in);
 }
 
+ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      reduction_op_type_(src->reduction_op_type_),
+      init_(ir_cloner->clone(src->init_)),
+      out_(ir_cloner->clone(src->out_)),
+      in_(ir_cloner->clone(src->in_)),
+      is_allreduce_(src->is_allreduce_) {}
+
+Expr* ReductionOp::shallowCopy() const {
+  auto result = IrBuilder::create<ReductionOp>(
+      reduction_op_type_, init_, out_, in_, is_allreduce_, etype());
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
+bool ReductionOp::sameAs(const Statement* other) const {
+  if (this == other) {
+    return true;
+  }
+  if (!other->isA<ReductionOp>()) {
+    return false;
+  }
+  const auto other_op = other->as<ReductionOp>();
+  // Note that init is not part of input vals, so it must be checked separately.
+  return (
+      Expr::sameAs(other) &&
+      getReductionOpType() == other_op->getReductionOpType() &&
+      init()->sameAs(other_op->init()));
+}
+
 GroupedReductionOp::GroupedReductionOp(
     IrBuilderPasskey passkey,
     std::vector<BinaryOpType> reduction_op_types,
@@ -666,6 +747,18 @@ GroupedReductionOp::GroupedReductionOp(
       init_vals_(ir_cloner->clone(src->init_vals_)),
       is_allreduce_(src->is_allreduce_) {}
 
+Expr* GroupedReductionOp::shallowCopy() const {
+  auto result = IrBuilder::create<GroupedReductionOp>(
+      reduction_op_types_,
+      init_vals_,
+      outputs(),
+      inputs(),
+      is_allreduce_,
+      etype());
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const {
   auto it = std::find(outputs().begin(), outputs().end(), output_val);
   if (it != outputs().end()) {
@@ -840,6 +933,13 @@ WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner)
       init_(src->init_.clone(ir_cloner)),
       is_allreduce_(src->is_allreduce_) {}
 
+Expr* WelfordOp::shallowCopy() const {
+  auto result =
+      IrBuilder::create<WelfordOp>(output_, input_, init_, is_allreduce_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 Val* WelfordOp::getInitValOfOutput(Val* output_val) const {
   auto val_name = output().getNameOf(output_val);
 
@@ -989,6 +1089,13 @@ GroupedWelfordOp::GroupedWelfordOp(
       init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)),
       is_allreduce_(src->is_allreduce_) {}
 
+Expr* GroupedWelfordOp::shallowCopy() const {
+  auto result = IrBuilder::create<GroupedWelfordOp>(
+      output_vals_, input_vals_, init_vals_, is_allreduce_, etype());
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool GroupedWelfordOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -1083,6 +1190,13 @@ MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner)
       init_(ir_cloner->clone(src->init_)),
       options_(src->options_) {}
 
+Expr* MmaOp::shallowCopy() const {
+  auto result = IrBuilder::create<MmaOp>(out_, in_a_, in_b_, init_);
+  result->options_ = options_;
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool MmaOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -1095,29 +1209,6 @@ bool MmaOp::sameAs(const Statement* other) const {
   return false;
 }
 
-ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
-    : Expr(src, ir_cloner),
-      reduction_op_type_(src->reduction_op_type_),
-      init_(ir_cloner->clone(src->init_)),
-      out_(ir_cloner->clone(src->out_)),
-      in_(ir_cloner->clone(src->in_)),
-      is_allreduce_(src->is_allreduce_) {}
-
-bool ReductionOp::sameAs(const Statement* other) const {
-  if (this == other) {
-    return true;
-  }
-  if (!other->isA<ReductionOp>()) {
-    return false;
-  }
-  const auto other_op = other->as<ReductionOp>();
-  // Note that init is not part of input vals, so it must be checked separately.
-  return (
-      Expr::sameAs(other) &&
-      getReductionOpType() == other_op->getReductionOpType() &&
-      init()->sameAs(other_op->init()));
-}
-
 TransposeOp::TransposeOp(
     IrBuilderPasskey passkey,
     TensorView* out,
@@ -1162,6 +1253,12 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner)
       in_(ir_cloner->clone(src->in_)),
       new2old_(src->new2old_) {}
 
+Expr* TransposeOp::shallowCopy() const {
+  auto result = IrBuilder::create<TransposeOp>(out_, in_, new2old_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 std::vector<int64_t> TransposeOp::old2new() const {
   std::vector<int64_t> old2new(new2old_.size());
   for (auto new_axis : c10::irange(new2old_.size())) {
@@ -1201,6 +1298,12 @@ ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner)
   }
 }
 
+Expr* ExpandOp::shallowCopy() const {
+  auto result = IrBuilder::create<ExpandOp>(out_, in_, expanded_extents_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 ShiftOp::ShiftOp(
     IrBuilderPasskey passkey,
     Val* out,
@@ -1248,6 +1351,12 @@ ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner)
       offsets_(src->offsets_),
       pad_width_(src->pad_width_) {}
 
+Expr* ShiftOp::shallowCopy() const {
+  auto result = IrBuilder::create<ShiftOp>(out_, in_, offsets_, pad_width_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool ShiftOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -1310,6 +1419,13 @@ GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner)
       window_shape_(src->window_shape_),
       pad_width_(src->pad_width_) {}
 
+Expr* GatherOp::shallowCopy() const {
+  auto result =
+      IrBuilder::create<GatherOp>(out_, in_, window_shape_, pad_width_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool GatherOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -1356,6 +1472,12 @@ ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner)
       vector_id_(ir_cloner->clone(src->vector_id_)),
       index_(ir_cloner->clone(src->index_)) {}
 
+Expr* ViewAsScalar::shallowCopy() const {
+  auto result = IrBuilder::create<ViewAsScalar>(out_, in_, vector_id_, index_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in)
     : Expr(passkey, ExprType::ViewOp), out_(out), in_(in) {
   addOutput(out);
@@ -1367,6 +1489,12 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner)
       out_(ir_cloner->clone(src->out_)),
       in_(ir_cloner->clone(src->in_)) {}
 
+Expr* ViewOp::shallowCopy() const {
+  auto result = IrBuilder::create<ViewOp>(out_, in_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 LoadStoreOp::LoadStoreOp(
     IrBuilderPasskey passkey,
     LoadStoreOpType op_type,
@@ -1386,6 +1514,12 @@ LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner)
       out_(ir_cloner->clone(src->out_)),
       in_(ir_cloner->clone(src->in_)) {}
 
+Expr* LoadStoreOp::shallowCopy() const {
+  auto result = IrBuilder::create<LoadStoreOp>(load_store_type_, out_, in_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent)
     : start_(_start), extent_(_extent) {
   TORCH_INTERNAL_ASSERT(
@@ -2496,6 +2630,13 @@ Split::Split(const Split* src, IrCloner* ir_cloner)
       start_offset_(ir_cloner->clone(src->start_offset_)),
       stop_offset_(ir_cloner->clone(src->stop_offset_)) {}
 
+Expr* Split::shallowCopy() const {
+  auto result = IrBuilder::create<Split>(
+      outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) {
   TORCH_INTERNAL_ASSERT(in_extent != nullptr);
 
@@ -2541,6 +2682,12 @@ Merge::Merge(const Merge* src, IrCloner* ir_cloner)
       outer_(ir_cloner->clone(src->outer_)),
       inner_(ir_cloner->clone(src->inner_)) {}
 
+Expr* Merge::shallowCopy() const {
+  auto result = IrBuilder::create<Merge>(out_, outer_, inner_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool Merge::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
@@ -2572,6 +2719,13 @@ Swizzle2D::Swizzle2D(
   addInput(in_y);
 }
 
+Expr* Swizzle2D::shallowCopy() const {
+  auto result = IrBuilder::create<Swizzle2D>(
+      out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool Swizzle2D::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
   }
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h
index ad9f67180f92..adfc64fc74ad 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.h
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.h
@@ -325,7 +325,7 @@ TORCH_CUDA_CU_API std::vector<ViewOp*> getViewOps(Fusion*);
 
 template <typename T>
 std::string toString(const T& nodes) {
   std::stringstream ss;
-  for (Statement* stmt : nodes) {
+  for (const Statement* stmt : nodes) {
     if (ss.tellp() != 0) {
       ss << ", ";
     }
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
index 132b99b31c34..7e69f0307a7a 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -78,6 +78,15 @@ TensorIndex::TensorIndex(
   }
 }
 
+Val* TensorIndex::index(int i) const {
+  TORCH_INTERNAL_ASSERT(
+      nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
+  if (i < 0)
+    i += nDims();
+  TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims()));
+  return indices_[i];
+}
+
 BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync)
     : Expr(passkey, ExprType::BlockSync), war_sync_(war_sync) {
   TORCH_INTERNAL_ASSERT(
@@ -85,6 +94,12 @@ BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync)
       "IR type only valid for Kernel container.");
 }
 
+Expr* BlockSync::shallowCopy() const {
+  auto result = IrBuilder::create<BlockSync>(war_sync_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 GridSync::GridSync(
     IrBuilderPasskey passkey,
     ParallelTypeBitmap sync_dims,
     Val* sync_buffer)
@@ -93,6 +108,12 @@ GridSync::GridSync(
     : Expr(passkey, ExprType::GridSync),
       sync_dims_(sync_dims),
       sync_buffer_(sync_buffer) {}
 
+Expr* GridSync::shallowCopy() const {
+  auto result = IrBuilder::create<GridSync>(sync_dims_, sync_buffer_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
     : Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) {
   TORCH_INTERNAL_ASSERT(
@@ -100,6 +121,12 @@ CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
       "IR type only valid for Kernel container.");
 }
 
+Expr* CpAsyncWait::shallowCopy() const {
+  auto result = IrBuilder::create<CpAsyncWait>(keep_stages_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
@@ -107,6 +134,12 @@ CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey)
      "IR type only valid for Kernel container.");
 }
 
+Expr* CpAsyncCommit::shallowCopy() const {
+  auto result = IrBuilder::create<CpAsyncCommit>();
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 InitMagicZero::InitMagicZero(IrBuilderPasskey passkey)
     : Expr(passkey, ExprType::InitMagicZero) {
   TORCH_INTERNAL_ASSERT(
@@ -114,6 +147,12 @@ InitMagicZero::InitMagicZero(IrBuilderPasskey passkey)
       "IR type only valid for Kernel container.");
 }
 
+Expr* InitMagicZero::shallowCopy() const {
+  auto result = IrBuilder::create<InitMagicZero>();
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey)
     : Expr(passkey, ExprType::UpdateMagicZero) {
   TORCH_INTERNAL_ASSERT(
@@ -121,6 +160,12 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey)
       "IR type only valid for Kernel container.");
 }
 
+Expr* UpdateMagicZero::shallowCopy() const {
+  auto result = IrBuilder::create<UpdateMagicZero>();
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 namespace {
 
 bool isIntegralScalar(const Val* val) {
@@ -147,6 +192,12 @@ PairSelect::PairSelect(
   TORCH_INTERNAL_ASSERT(isIntegralScalar(out), "Integer only for this op");
 }
 
+Expr* PairSelect::shallowCopy() const {
+  auto result = IrBuilder::create<PairSelect>(out_, in_, selection_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 Swizzle2DInt::Swizzle2DInt(
     IrBuilderPasskey passkey,
     IntPair* out,
@@ -172,6 +223,13 @@ Swizzle2DInt::Swizzle2DInt(
   addInput(extent_y);
 }
 
+Expr* Swizzle2DInt::shallowCopy() const {
+  auto result = IrBuilder::create<Swizzle2DInt>(
+      out_, in_x_, in_y_, extent_x_, extent_y_, swizzle_type_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 void Scope::insert(std::vector<Expr*>::const_iterator pos, Expr* expr) {
   exprs_.insert(pos, expr);
 }
@@ -307,6 +365,22 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other)
       "IR type only valid for Kernel container.");
 }
 
+Expr* ForLoop::shallowCopy() const {
+  auto result = IrBuilder::create<ForLoop>(
+      iter_domain_,
+      index_,
+      start_,
+      stop_,
+      step_,
+      vectorize_,
+      vectorize_shift_,
+      unroll_required_,
+      double_buffer_loop_stage_);
+  result->body_ = body_;
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 bool ForLoop::isUnrollable() const {
   // Start and stop must be constant, must not be a broadcast
   // dimension, cannot be bound to a parallel dimension, must not be
@@ -426,13 +500,12 @@ IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond)
   addInput(cond);
 }
 
-Val* TensorIndex::index(int i) const {
-  TORCH_INTERNAL_ASSERT(
-      nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
-  if (i < 0)
-    i += nDims();
-  TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims()));
-  return indices_[i];
+Expr* IfThenElse::shallowCopy() const {
+  auto result = IrBuilder::create<IfThenElse>(predicate());
+  result->then_body_ = then_body_;
+  result->else_body_ = else_body_;
+  result->setWritePredicate(writePredicate());
+  return result;
 }
 
 Allocate::Allocate(
@@ -495,6 +568,13 @@ Allocate::Allocate(
       "IR type only valid for Kernel container.");
 }
 
+Expr* Allocate::shallowCopy() const {
+  auto result =
+      IrBuilder::create<Allocate>(buffer_, memory_type_, shape_, zero_init_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 GridReduction::GridReduction(
     IrBuilderPasskey passkey,
     BinaryOpType reduction_op_type,
@@ -523,6 +603,22 @@ GridReduction::GridReduction(
       "IR type only valid for Kernel container.");
 }
 
+Expr* GridReduction::shallowCopy() const {
+  auto result = IrBuilder::create<GridReduction>(
+      getReductionOpType(),
+      init(),
+      out(),
+      in(),
+      reduction_buffer_,
+      sync_buffer_,
+      entrance_index_,
+      entrances_,
+      isAllreduce());
+  result->copyPredicatesFrom(this);
+  result->thread_predicate_ = thread_predicate_;
+  return result;
+}
+
 GroupedGridReduction::GroupedGridReduction(
     IrBuilderPasskey passkey,
     std::vector<BinaryOpType> reduction_op_types,
@@ -553,6 +649,23 @@ GroupedGridReduction::GroupedGridReduction(
       "IR type only valid for Kernel container.");
 }
 
+Expr* GroupedGridReduction::shallowCopy() const {
+  auto result = IrBuilder::create<GroupedGridReduction>(
+      getReductionOpTypes(),
+      initVals(),
+      outputs(),
+      inputs(),
+      reduction_buffers_,
+      sync_buffer_,
+      entrance_index_,
+      entrances_,
+      buffer_stride_,
+      isAllreduce());
+  result->copyPredicatesFrom(this);
+  result->thread_predicate_ = thread_predicate_;
+  return result;
+}
+
 GridBroadcast::GridBroadcast(
     IrBuilderPasskey passkey,
     BroadcastOp* broadcast_op,
@@ -567,6 +680,13 @@ GridBroadcast::GridBroadcast(
       "IR type only valid for Kernel container.");
 }
 
+Expr* GridBroadcast::shallowCopy() const {
+  auto result = IrBuilder::create<GridBroadcast>(
+      broadcast_op_, broadcast_buffer_, sync_buffer_);
+  result->copyPredicatesFrom(this);
+  return result;
+}
+
 GridWelford::GridWelford(
     IrBuilderPasskey passkey,
     WelfordOp* welford_op,
@@ -589,6 +709,20 @@ GridWelford::GridWelford(
       "IR type only valid for Kernel container.");
 }
 
+Expr* GridWelford::shallowCopy() const {
+  auto result = IrBuilder::create<GridWelford>(
+      welford_op_,
+      var_buffer_,
+      avg_buffer_,
+      n_buffer_,
+      sync_buffer_,
+      entrance_index_,
+      entrances_);
+  result->copyPredicatesFrom(this);
+  result->thread_predicate_ = thread_predicate_;
+  return result;
+}
+
 GroupedGridWelford::GroupedGridWelford(
     IrBuilderPasskey passkey,
     std::vector<WelfordTriplet> output_vals,
@@ -617,6 +751,22 @@ GroupedGridWelford::GroupedGridWelford(
       "IR type only valid for Kernel container.");
 }
 
+Expr* GroupedGridWelford::shallowCopy() const {
+  auto result = IrBuilder::create<GroupedGridWelford>(
+      outputVals(),
+      inputVals(),
+      initVals(),
+      reduction_buffers_,
+      sync_buffer_,
+      entrance_index_,
+      entrances_,
+      buffer_stride_,
+      isAllreduce());
+  result->copyPredicatesFrom(this);
+  result->thread_predicate_ = thread_predicate_;
+  return result;
+}
+
 AllocateFusedReduction::AllocateFusedReduction(
     IrBuilderPasskey passkey,
     GridReduction* grid_reduction)
@@ -657,6 +807,36 @@ AllocateFusedReduction::AllocateFusedReduction(
       "IR type only valid for Kernel container.");
 }
 
+Expr* AllocateFusedReduction::shallowCopy() const {
+  if (grid_expr_->isA<GridReduction>()) {
+    auto result = IrBuilder::create<AllocateFusedReduction>(
+        grid_expr_->as<GridReduction>());
+    result->setPredicate(predicate());
+    result->setWritePredicate(writePredicate());
+    return result;
+  } else if (grid_expr_->isA<GroupedGridReduction>()) {
+    auto result = IrBuilder::create<AllocateFusedReduction>(
+        grid_expr_->as<GroupedGridReduction>());
+    result->setPredicate(predicate());
+    result->setWritePredicate(writePredicate());
+    return result;
+  } else if (grid_expr_->isA<GridWelford>()) {
+    auto result = IrBuilder::create<AllocateFusedReduction>(
+        grid_expr_->as<GridWelford>());
+    result->setPredicate(predicate());
+    result->setWritePredicate(writePredicate());
+    return result;
+  } else if (grid_expr_->isA<GroupedGridWelford>()) {
+    auto result = IrBuilder::create<AllocateFusedReduction>(
+        grid_expr_->as<GroupedGridWelford>());
+    result->setPredicate(predicate());
+    result->setWritePredicate(writePredicate());
+    return result;
+  }
+  TORCH_INTERNAL_ASSERT(
+      false, "Unknown reduction type in AllocateFusedReduction::shallowCopy");
+}
+
 TensorIndex* AllocateFusedReduction::out() const {
   TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr);
   if (grid_expr_->isA<GridReduction>() ||
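One subtlety in the scope-carrying ops above: ForLoop::shallowCopy() and IfThenElse::shallowCopy() copy their Scope members by assignment (result->body_ = body_;). Since a Scope holds a std::vector of Expr* (see Scope::insert() above), the vector of child pointers is duplicated but the child expressions themselves are shared. A sketch of the resulting aliasing (the loop variable `fl` is illustrative):

    kir::ForLoop* copy = fl->shallowCopy()->as<kir::ForLoop>();
    // copy->body() holds the same Expr* elements as fl->body():
    // mutating a child expression is observable through both loops, while
    // pushing a new expression into one body leaves the other unchanged.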
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h
index 62b245772dd0..cd44e8d8e21b 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -94,7 +94,7 @@ class TORCH_CUDA_CU_API Predicate final : public Val {
     return expr_;
   }
 
-  Bool* thread_pred() {
+  Bool* thread_pred() const {
     TORCH_INTERNAL_ASSERT(
         ptype_ == PredicateType::Inline ||
         ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift ||
@@ -199,6 +199,8 @@ class TORCH_CUDA_CU_API Allocate final : public Expr {
       Val* size,
       bool zero_init = false);
 
+  Expr* shallowCopy() const override;
+
   Val* buffer() const {
     return buffer_;
   }
@@ -251,6 +253,8 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr {
  public:
   explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false);
 
+  Expr* shallowCopy() const override;
+
   bool isWarHazardSync() const {
     return war_sync_;
   }
@@ -265,6 +269,8 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
  public:
   explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0);
 
+  Expr* shallowCopy() const override;
+
   //! Returns the remaining number of stages that are not synchronized
   //! after this op.
   unsigned int keepStages() const {
@@ -282,6 +288,8 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
 class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr {
  public:
   explicit CpAsyncCommit(IrBuilderPasskey passkey);
+
+  Expr* shallowCopy() const override;
 };
 
 // Synchronize all blocks in device, implies cooperative group launch is
@@ -293,6 +301,8 @@ class TORCH_CUDA_CU_API GridSync final : public Expr {
       ParallelTypeBitmap sync_dims,
       Val* sync_buffer);
 
+  Expr* shallowCopy() const override;
+
   ParallelTypeBitmap syncDims() const {
     return sync_dims_;
   }
@@ -311,6 +321,8 @@ class TORCH_CUDA_CU_API GridSync final : public Expr {
 class TORCH_CUDA_CU_API InitMagicZero final : public Expr {
  public:
   explicit InitMagicZero(IrBuilderPasskey passkey);
+
+  Expr* shallowCopy() const override;
 };
 
 // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero
@@ -318,6 +330,8 @@ class TORCH_CUDA_CU_API InitMagicZero final : public Expr {
 class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr {
  public:
   explicit UpdateMagicZero(IrBuilderPasskey passkey);
+
+  Expr* shallowCopy() const override;
 };
 
 // TODO(kir): promote to IR node
@@ -418,6 +432,8 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr {
 
   ForLoop(IrBuilderPasskey passkey, const ForLoop* other);
 
+  Expr* shallowCopy() const override;
+
   Val* index() const {
     return index_;
   }
@@ -512,6 +528,8 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr {
  public:
   explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond);
 
+  Expr* shallowCopy() const override;
+
   Scope& thenBody() {
     return then_body_;
   }
@@ -557,6 +575,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
       Val* entrances,
       bool is_allreduce = false);
 
+  Expr* shallowCopy() const override;
+
   Allocate* reduction_buffer() const {
     return reduction_buffer_;
   }
@@ -579,8 +599,11 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GridReduction* withThreadPredicate(
+      const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GridReduction>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
@@ -609,6 +632,8 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp {
       Val* buffer_stride,
       bool is_allreduce = false);
 
+  Expr* shallowCopy() const override;
+
   const std::vector<Allocate*>& reduction_buffers() const {
     return reduction_buffers_;
   }
@@ -639,8 +664,11 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GroupedGridReduction* withThreadPredicate(
+      const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GroupedGridReduction>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
@@ -671,6 +699,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr {
       Allocate* broadcast_buffer,
       Allocate* sync_buffer);
 
+  Expr* shallowCopy() const override;
+
   BroadcastOp* broadcast_op() const {
     return broadcast_op_;
   }
@@ -710,6 +740,8 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr {
       Val* entrance_index,
       Val* entrances);
 
+  Expr* shallowCopy() const override;
+
   WelfordOp* welford_op() const {
     return welford_op_;
   }
@@ -744,8 +776,10 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GridWelford>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
@@ -777,6 +811,8 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp {
       Val* buffer_stride,
       bool is_allreduce = false);
 
+  Expr* shallowCopy() const override;
+
   const std::array<std::vector<Allocate*>, 3>& reduction_buffers() const {
     return reduction_buffers_;
   }
@@ -803,8 +839,11 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp {
     return thread_predicate_;
   }
 
-  void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) {
-    thread_predicate_ = thread_predicate;
+  GroupedGridWelford* withThreadPredicate(
+      const ParallelTypeBitmap& thread_predicate) {
+    auto result = shallowCopy()->as<GroupedGridWelford>();
+    result->thread_predicate_ = thread_predicate;
+    return result;
   }
 
  private:
@@ -839,6 +878,8 @@ class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr {
       IrBuilderPasskey passkey,
       GroupedGridWelford* grouped_grid_welford);
 
+  Expr* shallowCopy() const override;
+
   Expr* gridExpr() const {
     return grid_expr_;
   }
@@ -879,6 +920,8 @@ class TORCH_CUDA_CU_API PairSelect : public Expr {
 
   PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection);
 
+  Expr* shallowCopy() const override;
+
   Val* out() const {
     return out_;
   }
@@ -914,6 +957,8 @@ class TORCH_CUDA_CU_API Swizzle2DInt : public Expr {
       Val* extent_y,
       Swizzle2DType swizzle_type);
 
+  Expr* shallowCopy() const override;
+
   IntPair* out() const {
     return out_;
   }
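Because withPredicate()/withWritePredicate() are declared on Expr and return Expr*, the lowering code in lower_index.cpp below rebinds each pointer and downcasts back to the concrete type, while the per-class withThreadPredicate() methods above already return the concrete type. The recurring idiom, condensed (the variables are illustrative):

    grid_reduction = grid_reduction->withThreadPredicate(thread_pred);
    if (pred != nullptr) {
      // withPredicate() returns Expr*; recover the concrete type with as<>().
      grid_reduction =
          grid_reduction->withPredicate(pred)->as<kir::GridReduction>();
    }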
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp
index f03f36bdaa71..9f72e977fcde 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -411,10 +411,12 @@ void IndexLowering::handleBlockReduction(
   ReductionOp* indexed_rop = IrBuilder::create<ReductionOp>(
       rop->getReductionOpType(), rop->init(), out, in, rop->isAllreduce());
   if (rop->predicate()) {
-    indexed_rop->setPredicate(rop->predicate());
+    indexed_rop =
+        indexed_rop->withPredicate(rop->predicate())->as<ReductionOp>();
   }
   if (rop->writePredicate()) {
-    indexed_rop->setWritePredicate(rop->writePredicate());
+    indexed_rop = indexed_rop->withWritePredicate(rop->writePredicate())
+                      ->as<ReductionOp>();
   }
 
   pushBack(indexed_rop);
@@ -493,13 +495,15 @@ void IndexLowering::handleGridReduction(
       n_entrances,
       rop->isAllreduce());
 
-  grid_reduction->setThreadPredicate(thread_pred);
+  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);
 
   if (rop->predicate()) {
-    grid_reduction->setPredicate(rop->predicate());
+    grid_reduction = grid_reduction->withPredicate(rop->predicate())
+                         ->as<kir::GridReduction>();
   }
   if (rop->writePredicate()) {
-    grid_reduction->setWritePredicate(rop->writePredicate());
+    grid_reduction = grid_reduction->withWritePredicate(rop->writePredicate())
+                         ->as<kir::GridReduction>();
   }
 
   pushBack(grid_reduction);
@@ -556,10 +560,12 @@ void IndexLowering::handleBlockReduction(
       inputs,
       grouped_rop->isAllreduce());
   if (grouped_rop->predicate()) {
-    indexed_rop->setPredicate(grouped_rop->predicate());
+    indexed_rop = indexed_rop->withPredicate(grouped_rop->predicate())
+                      ->as<GroupedReductionOp>();
   }
   if (grouped_rop->writePredicate()) {
-    indexed_rop->setWritePredicate(grouped_rop->writePredicate());
+    indexed_rop = indexed_rop->withWritePredicate(grouped_rop->writePredicate())
+                      ->as<GroupedReductionOp>();
   }
 
   pushBack(indexed_rop);
@@ -638,13 +644,16 @@ void IndexLowering::handleGridReduction(
       work_buf_size_info.buffer_stride,
       grouped_rop->isAllreduce());
 
-  grid_reduction->setThreadPredicate(thread_pred);
+  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);
 
   if (grouped_rop->predicate()) {
-    grid_reduction->setPredicate(grouped_rop->predicate());
+    grid_reduction = grid_reduction->withPredicate(grouped_rop->predicate())
+                         ->as<kir::GroupedGridReduction>();
   }
   if (grouped_rop->writePredicate()) {
-    grid_reduction->setWritePredicate(grouped_rop->writePredicate());
+    grid_reduction =
+        grid_reduction->withWritePredicate(grouped_rop->writePredicate())
+            ->as<kir::GroupedGridReduction>();
   }
 
   pushBack(grid_reduction);
@@ -706,10 +715,11 @@ void IndexLowering::handle(const WelfordOp* wop) {
       wop->isAllreduce());
 
   if (wop->predicate()) {
-    indexed_wop->setPredicate(wop->predicate());
+    indexed_wop = indexed_wop->withPredicate(wop->predicate())->as<WelfordOp>();
   }
   if (wop->writePredicate()) {
-    indexed_wop->setWritePredicate(wop->writePredicate());
+    indexed_wop =
+        indexed_wop->withWritePredicate(wop->writePredicate())->as<WelfordOp>();
   }
 
   // Serial welford
@@ -785,22 +795,27 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
       entrance_ind,
       n_entrances);
 
-  grid_welford->setThreadPredicate(thread_pred);
+  grid_welford = grid_welford->withThreadPredicate(thread_pred);
 
   const bool block_reduce_separated =
       out_domain->hasBlockReduction() && !indexed_wop->isAllreduce();
 
   if (indexed_wop->predicate()) {
    if (block_reduce_separated) {
-      grid_welford->setPredicate(IrBuilder::create<kir::Predicate>(
-          GpuLower::current()->kernel()->trueVal()));
+      grid_welford = grid_welford
+                         ->withPredicate(IrBuilder::create<kir::Predicate>(
+                             GpuLower::current()->kernel()->trueVal()))
+                         ->as<kir::GridWelford>();
    } else {
-      grid_welford->setPredicate(indexed_wop->predicate());
+      grid_welford = grid_welford->withPredicate(indexed_wop->predicate())
+                         ->as<kir::GridWelford>();
    }
   }
 
   if (indexed_wop->writePredicate()) {
-    grid_welford->setWritePredicate(indexed_wop->writePredicate());
+    grid_welford =
+        grid_welford->withWritePredicate(indexed_wop->writePredicate())
+            ->as<kir::GridWelford>();
   }
 
   if (block_reduce_separated) {
@@ -945,13 +960,15 @@ void IndexLowering::handleGroupedGridWelford(
       work_buf_size_info.buffer_stride,
       op->isAllreduce());
 
-  indexed_op->setThreadPredicate(thread_pred);
+  indexed_op = indexed_op->withThreadPredicate(thread_pred);
 
   if (op->predicate()) {
-    indexed_op->setPredicate(op->predicate());
+    indexed_op = indexed_op->withPredicate(op->predicate())
+                     ->as<kir::GroupedGridWelford>();
   }
   if (op->writePredicate()) {
-    indexed_op->setWritePredicate(op->writePredicate());
+    indexed_op = indexed_op->withWritePredicate(op->writePredicate())
+                     ->as<kir::GroupedGridWelford>();
   }
 
   pushBack(indexed_op);
@@ -997,7 +1014,8 @@ void IndexLowering::handle(const BroadcastOp* bop) {
   const bool block_z = parallel_bitmap.get(ParallelType::BIDz);
 
   if (bop->predicate()) {
-    indexed_expr->setPredicate(bop->predicate());
+    indexed_expr =
+        indexed_expr->withPredicate(bop->predicate())->as<BroadcastOp>();
   }
 
   const bool grid_broadcast_needed = block_x || block_y || block_z;
@@ -1024,7 +1042,8 @@ void IndexLowering::handle(const BroadcastOp* bop) {
       indexed_expr, work_buffer, sync_buffer);
 
   if (bop->predicate()) {
-    grid_broadcast->setPredicate(bop->predicate());
+    grid_broadcast = grid_broadcast->withPredicate(bop->predicate())
+                         ->as<kir::GridBroadcast>();
   }
 
   pushBack(grid_broadcast);
diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp
index 989c00be81b7..4c7a73f5bfdf 100644
--- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp
@@ -20,7 +20,7 @@ namespace cuda {
 
 namespace {
 
-class ConditionalFromPredicateModifier : public kir::IrVisitor {
+class ConditionalFromPredicateModifier : public kir::ExprMutator {
  public:
   ConditionalFromPredicateModifier() = delete;
 
@@ -32,11 +32,11 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
  private:
   ConditionalFromPredicateModifier(const std::vector<Expr*>& exprs) {
     FUSER_PERF_SCOPE(
-        "GpuLower::Lower::ConditionalFromPredicateModifier::process");
-    kir::IrVisitor::handle(exprs);
+        "ConditionalFromPredicateModifier::ConditionalFromPredicateModifier");
+    traverseAndInsert(exprs);
   }
 
-  using kir::IrVisitor::handle;
+  using kir::ExprMutator::handle;
 
   void handle(Expr* expr) final {
     if (expr != nullptr && expr->predicate() != nullptr) {
@@ -72,7 +72,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
       TORCH_INTERNAL_ASSERT(conditional != nullptr);
       expr->predicate()->setValue(conditional);
       TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr);
-      setWritePredicate(expr, conditional);
+      setWritePredicate(expr);
     }
 
     // Note: [Predicate Inversion for CpAsync]
@@ -101,7 +101,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
       invertPredicateForGmemToSharedMemInitialize(expr);
     }
 
-    kir::IrVisitor::handle(expr);
+    kir::ExprMutator::handle(expr);
   }
 
   // Invert the predicate of given expr.
@@ -123,7 +123,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
         ir_utils::isCpAsyncInit(maybe_init.value());
   }
 
-  void setWritePredicate(Expr* expr, Bool* read_cond) {
+  void setWritePredicate(Expr* expr) {
     if (expr->writePredicate() != nullptr) {
       auto write_cond = generateConditional(expr->writePredicate());
       if (write_cond) {
@@ -131,7 +131,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
       } else {
         // If generateConditional returns null, it means no specific
         // predicate needs to be used.
-        expr->setWritePredicate(nullptr);
+        registerReplace(expr, expr->withWritePredicate(nullptr));
       }
     }
   }
@@ -150,7 +150,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor {
       ite->predicate()->setValue(conditional);
       TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr);
     }
-    kir::IrVisitor::handle(ite);
+    kir::ExprMutator::handle(ite);
   }
 
   // Generate conditional according to PredicateType
diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp
index 991ed5745c87..2a7c04243f4c 100644
--- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp
@@ -17,7 +17,7 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-void ShiftPredicateInserter::insert(
+Expr* ShiftPredicateInserter::insert(
     Expr* expr,
     const std::vector<kir::ForLoop*>& loops,
     Bool* thread_pred,
@@ -30,7 +30,7 @@ void ShiftPredicateInserter::insert(
   const bool needs_shift_predicate =
       gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition());
   if (!needs_shift_predicate) {
-    return;
+    return expr;
   }
 
   // The conditional branches to create:
@@ -57,8 +57,7 @@ void ShiftPredicateInserter::insert(
   // the expr with shift_pred. Since the expr is not shift, the
   // padding is safe to omit.
   if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) {
-    expr->setPredicate(shift_pred);
-    return;
+    return expr->withPredicate(shift_pred);
   }
 
   auto shift_ite = IrBuilder::create<kir::IfThenElse>(shift_pred);
@@ -76,7 +75,7 @@ void ShiftPredicateInserter::insert(
 
   // No padding condition is required if this is within unswitch.
   if (within_unswitch) {
-    return;
+    return expr;
   }
 
   // Padding by zero
@@ -89,6 +88,8 @@ void ShiftPredicateInserter::insert(
   bounds_ite->thenBody().push_back(pad_expr);
   // Insert the else block
   shift_ite->elseBody().push_back(bounds_ite);
+
+  return expr;
 }
 
 int AxisHaloInfo::width() const {
diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h
index 0cb3c3ea4457..f12410703d99 100644
--- a/torch/csrc/jit/codegen/cuda/lower_shift.h
+++ b/torch/csrc/jit/codegen/cuda/lower_shift.h
@@ -225,7 +225,7 @@ class ShiftPredicateInserter {
   //! the generated predicate. The branch structure is different from
   //! the usual predicated expression, so the insertion is also done
   //! here.
-  static void insert(
+  static Expr* insert(
       Expr* expr,
       const std::vector<kir::ForLoop*>& loops,
       Bool* thread_pred,
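ShiftPredicateInserter::insert() now returns the expression that should stand in for `expr`: either `expr` itself (when the shift predicate is realized through an inserted IfThenElse) or a predicated shallow copy (the block-sync case). The caller contract, condensed from UnrollPass::handle() in the next file (names as in this patch):

    Expr* expr_with_predicate = ShiftPredicateInserter::insert(
        expr, for_loops_, thread_pred, unswitched_loop_);
    if (expr_with_predicate != expr) {
      // A copy was returned; swap it into the enclosing scope.
      registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
    }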
diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
index 8ba3909d19e7..97a544da9f64 100644
--- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp
@@ -54,6 +54,14 @@ bool isReductionInitExpr(const Expr* expr) {
 
 } // namespace
 
+void UnrollPass::registerReplace(
+    Expr* reference,
+    Expr* new_expr,
+    kir::Scope* scope) {
+  kir::ExprMutator::registerReplace(reference, new_expr, scope);
+  GpuLower::current()->propagateExprInfo(reference, new_expr);
+}
+
 void UnrollPass::handle(Expr* expr) {
   if (ir_utils::isTvOp(expr)) {
     // If tv op, predicate it
@@ -79,11 +87,16 @@ void UnrollPass::handle(Expr* expr) {
 
     non_trivial_pred_found_ = true;
 
+    Expr* expr_with_predicate = expr;
+
     // When a predicate needs to account for ShiftOp, it is currently
     // taken care by its own function.
     if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) {
-      ShiftPredicateInserter::insert(
+      expr_with_predicate = ShiftPredicateInserter::insert(
           expr, for_loops_, thread_pred, unswitched_loop_);
+      if (expr_with_predicate != expr) {
+        registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
+      }
       return;
     }
 
@@ -93,7 +106,7 @@ void UnrollPass::handle(Expr* expr) {
           ? thread_pred_expr
           : IrBuilder::create<kir::Predicate>(
                 PredicateType::ReductionWrite, expr, thread_pred);
-      expr->setWritePredicate(write_pred);
+      expr_with_predicate = expr_with_predicate->withWritePredicate(write_pred);
     }
 
     // For expr calling a device func with block sync, don't create
@@ -103,7 +116,8 @@ void UnrollPass::handle(Expr* expr) {
           ? thread_pred_expr
          : IrBuilder::create<kir::Predicate>(
                 PredicateType::Inline, expr, thread_pred);
-      expr->setPredicate(pred);
+      expr_with_predicate = expr_with_predicate->withPredicate(pred);
+      registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
       return;
     }
 
@@ -135,7 +149,10 @@ void UnrollPass::handle(Expr* expr) {
       kir::ExprMutator::registerReplace(
           expr, inline_ite, &for_loops_.back()->body());
     }
-    inline_ite->thenBody().push_back(expr);
+    if (expr != expr_with_predicate) {
+      GpuLower::current()->propagateExprInfo(expr, expr_with_predicate);
+    }
+    inline_ite->thenBody().push_back(expr_with_predicate);
   } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
     handle(for_loop);
   }
diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h
index 14725c405b77..786e45115ba6 100644
--- a/torch/csrc/jit/codegen/cuda/lower_unroll.h
+++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h
@@ -62,6 +62,8 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator {
   static bool canOmitElseClause(kir::ForLoop* fl);
 
  private:
+  void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope);
+
   // Generate the for Expr replacement map
   UnrollPass(const std::vector<Expr*>& exprs);
 
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index c3dc35c3412e..b98fbea263c3 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -446,8 +446,8 @@ class ReplaceExprInput : private kir::ExprMutator {
 
   // Copy predicates and register expression replacement
   void registerReplaceWithPredicate(Expr* old_expr, Expr* new_expr) {
-    new_expr->setPredicate(old_expr->predicate());
-    new_expr->setWritePredicate(old_expr->writePredicate());
+    new_expr = new_expr->withPredicate(old_expr->predicate())
+                   ->withWritePredicate(old_expr->writePredicate());
     registerReplace(old_expr, new_expr);
   }
 
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 556ff4f00825..0252a7785d67 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -1001,6 +1001,9 @@ struct DummyExpr : public Expr {
   DummyExpr& operator=(const DummyExpr& other) = delete;
   DummyExpr(DummyExpr&& other) = delete;
   DummyExpr& operator=(DummyExpr&& other) = delete;
+  Expr* shallowCopy() const override {
+    return nullptr;
+  }
 };
 
 TEST_F(NVFuserTest, FusionTopoSort_CUDA) {
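A general consequence of this patch: passes that previously mutated expressions in place and relied on stable Expr* identity must now register a replacement and propagate any expression-keyed metadata to the new node. UnrollPass::registerReplace() above is the canonical shape of that pattern; a sketch of the same recipe for any other kir::ExprMutator-based pass (the pass name and helper are illustrative, not from this patch):

    void SomePass::replaceWithPredicate(
        Expr* expr,
        kir::Predicate* pred,
        kir::Scope* scope) {
      Expr* new_expr = expr->withPredicate(pred);
      // Swap the predicated copy into the enclosing scope...
      kir::ExprMutator::registerReplace(expr, new_expr, scope);
      // ...and keep GpuLower's expr-keyed lowering info consistent.
      GpuLower::current()->propagateExprInfo(expr, new_expr);
    }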