diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index b794f1c20897..b52d8f474dde 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -2003,10 +2003,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { ArgumentBuilder read_preds; ArgumentBuilder write_preds; + auto output_vals = grouped_gwop->outputVals(); + auto input_vals = grouped_gwop->inputVals(); + auto init_vals = grouped_gwop->initVals(); + for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) { - const auto& output = grouped_gwop->outputVals().at(expr_index); - const auto& input = grouped_gwop->inputVals().at(expr_index); - const auto& init = grouped_gwop->initVals().at(expr_index); + const auto& output = output_vals.at(expr_index); + const auto& input = input_vals.at(expr_index); + const auto& init = init_vals.at(expr_index); for (const auto& group_index : c10::irange(index_replacement_maps.size())) { diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 251d688c5c8f..e307fa5ad61d 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -313,6 +313,10 @@ void Val::constDispatch(T handler, const Val* val) { case ValType::TensorIndex: ptr(handler)->handle(val->as()); return; + case ValType::Attribute: + // Attribute Val is just a wrapper for non-IR data, so there is nothing to + // handle + return; default: break; } diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 42bc9309e5d1..64abef3d44a8 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -34,6 +34,8 @@ Statement::Statement(const Statement* src, IrCloner* ir_cloner) { ir_container_ = ir_cloner->container(); } +NVFUSER_DEFINE_CLONE(Statement) + void Statement::setName(IrContainerPasskey, StmtNameType name) { name_ = name; } @@ -98,6 +100,8 @@ Val::Val(IrBuilderPasskey passkey, ValType _vtype, DataType _dtype) Val::Val(const Val* src, IrCloner* ir_cloner) : Statement(src, ir_cloner), vtype_(src->vtype_), dtype_(src->dtype_) {} +NVFUSER_DEFINE_CLONE(Val) + const std::vector& Val::uses() const { if (vtype_ == ValType::TensorView) { if (!fusion()->isTVUseInfoValid() && !fusion()->isUpdatingTVUseInfo()) { @@ -319,7 +323,27 @@ Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {} Expr::Expr(const Expr* src, IrCloner* ir_cloner) : Statement(src, ir_cloner), inputs_(ir_cloner->clone(src->inputs_)), - outputs_(ir_cloner->clone(src->outputs_)) {} + outputs_(ir_cloner->clone(src->outputs_)), + attributes_(ir_cloner->clone(src->attributes_)) {} + +Expr::Expr( + IrBuilderPasskey passkey, + std::vector inputs, + std::vector outputs, + std::vector attributes) + : Statement(passkey), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + attributes_(std::move(attributes)) {} + +Expr* Expr::shallowCopy() const { + auto result = newObject(inputs(), outputs(), attributes()); + if (container()->isA()) { + result->predicate_ = predicate_; + result->write_predicate_ = write_predicate_; + } + return result; +} bool Expr::sameAs(const Statement* other) const { if (this == other) { @@ -333,7 +357,8 @@ bool Expr::sameAs(const Statement* other) const { return false; } if (inputs().size() != other_expr->inputs().size() || - outputs().size() != other_expr->outputs().size()) { + outputs().size() != other_expr->outputs().size() || + attributes().size() != 
other_expr->attributes().size()) { return false; } for (const auto i : c10::irange(inputs().size())) { @@ -341,6 +366,11 @@ bool Expr::sameAs(const Statement* other) const { return false; } } + for (const auto i : c10::irange(attributes().size())) { + if (!attribute(i)->sameAs(other_expr->attribute(i))) { + return false; + } + } return true; } @@ -380,13 +410,6 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) { return result; } -void Expr::copyPredicatesFrom(const Expr* expr) { - if (container()->isA()) { - predicate_ = expr->predicate_; - write_predicate_ = expr->write_predicate_; - } -} - } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index f612a72352f3..e6b4f35d78d1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -70,6 +71,14 @@ class ExprPasskey { TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept; +#define NVFUSER_DECLARE_CLONE \ + virtual Statement* clone(IrCloner* ir_cloner) const override; + +#define NVFUSER_DEFINE_CLONE(ClassName) \ + Statement* ClassName::clone(IrCloner* ir_cloner) const { \ + return IrBuilder::clone(this, ir_cloner); \ + } + //! Statement is the highest level node representation. Everything that is //! considered "IR" will be derived from this class at some point. Both Values //! and Expr's are a Statement. If there will ever be any more fundamental @@ -159,6 +168,8 @@ class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { std::string toString() const; std::string toInlineString() const; + virtual Statement* clone(IrCloner* ir_cloner) const; + protected: Statement(IrBuilderPasskey); @@ -353,6 +364,8 @@ class TORCH_CUDA_CU_API Val : public Statement { void resolveIndexDtype(); + NVFUSER_DECLARE_CLONE + protected: friend Fusion; @@ -391,6 +404,31 @@ class TORCH_CUDA_CU_API Val : public Statement { int evaluator_index_ = -1; }; +//! A Val object that stores a plain data. Note that this class is only intended +//! to hold non-IR data, such as DataType, std::vector, etc. Please don't +//! use this class to hold IR nodes or their pointers. +template +class TORCH_CUDA_CU_API Attribute : public Val { + public: + T value; + Attribute(IrBuilderPasskey passkey, const T& value) + : Val(passkey, ValType::Attribute), value(value) {} + Attribute(const Attribute* src, IrCloner* ir_cloner) + : Val(src, ir_cloner), value(src->value) {} + template + Attribute(IrBuilderPasskey passkey, Args... args) + : Val(passkey, ValType::Attribute), value(std::forward(args)...) {} + + NVFUSER_DECLARE_CLONE + + bool sameAs(const Statement* other) const override { + if (auto pv = dynamic_cast(other)) { + return pv->value == value; + } + return false; + } +}; + //! A Expr represents a "computation." These are functions that takes inputs //! and produce outputs, inputs and outputs all being Vals. There are //! specializations of BinaryOp which takes 2 inputs and produces 1 output, and @@ -436,12 +474,25 @@ class TORCH_CUDA_CU_API Expr : public Statement { Expr(const Expr* src, IrCloner* ir_cloner); + Expr( + IrBuilderPasskey, + std::vector inputs, + std::vector outputs, + std::vector attributes); + // Creates a new instance of the expression with all its field copied. 
// Note that unlike IrCloner, this function only do a shallow copy - virtual Expr* shallowCopy() const = 0; + Expr* shallowCopy() const; bool sameAs(const Statement* other) const override; + // Creates a new instance of the same expression type with the given inputs, + // outputs, and attributes. + virtual Expr* newObject( + std::vector inputs, + std::vector outputs, + std::vector attributes) const = 0; + // Input/output accessors const auto& inputs() const { return inputs_; @@ -451,12 +502,24 @@ class TORCH_CUDA_CU_API Expr : public Statement { return outputs_; } + const auto& attributes() const { + return attributes_; + } + auto input(size_t index) const { - return inputs_[index]; + return inputs_.at(index); } auto output(size_t index) const { - return outputs_[index]; + return outputs_.at(index); + } + + auto attribute(size_t index) const { + return attributes_.at(index); + } + + auto attributeVal(size_t index) const { + return dynamic_cast(attributes_.at(index)); } // Dispatch functions, definitions in dispatch.cpp @@ -494,8 +557,6 @@ class TORCH_CUDA_CU_API Expr : public Statement { // TODO: Protect based on being in kernel container void setWritePredicate(kir::Predicate* write_predicate); - void copyPredicatesFrom(const Expr* expr); - // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); @@ -508,6 +569,11 @@ class TORCH_CUDA_CU_API Expr : public Statement { outputs_.push_back(output); } + // TODO: Add Fusion passkey + void addAttribute(Statement* attr) { + attributes_.push_back(attr); + } + ExprPasskey exprPasskey() { return ExprPasskey(); } @@ -515,6 +581,7 @@ class TORCH_CUDA_CU_API Expr : public Statement { private: std::vector inputs_; std::vector outputs_; + std::vector attributes_; kir::Predicate* predicate_ = nullptr; @@ -530,6 +597,24 @@ bool Val::isDefinitionType() const { return false; } +#define NVFUSER_DECLARE_CLONE_AND_CREATE \ + virtual Statement* clone(IrCloner* ir_cloner) const override; \ + virtual Expr* newObject( \ + std::vector inputs, \ + std::vector outputs, \ + std::vector attributes) const override; + +#define NVFUSER_DEFINE_CLONE_AND_CREATE(ClassName) \ + Statement* ClassName::clone(IrCloner* ir_cloner) const { \ + return IrBuilder::clone(this, ir_cloner); \ + } \ + Expr* ClassName::newObject( \ + std::vector inputs, \ + std::vector outputs, \ + std::vector attributes) const { \ + return IrBuilder::create(inputs, outputs, attributes); \ + } + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index fe60fa94a2d8..d704fe4ffef5 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -11,75 +11,6 @@ namespace jit { namespace fuser { namespace cuda { -//! Clone an IR node, forwarding the arguments to the IrCloner constructor. -template -T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { - TORCH_INTERNAL_ASSERT( - ir_cloner != nullptr, - "Cannot use create when a cloner object is set. 
Use clone."); - - TORCH_INTERNAL_ASSERT( - ir_cloner->container() != nullptr, - "Cloner doesn't have a valid container to store cloned object."); - - T* dest = new T(src, ir_cloner); - const Statement* src_stmt = dynamic_cast(src); - Statement* dest_stmt = dynamic_cast(dest); - - auto dest_container = ir_cloner->container(); - auto src_container = src_stmt->container(); - - dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); - - if (src_container != dest_container) { - dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); - } - - ir_cloner->registerClone(src_stmt, dest_stmt); - - return dest; -} - -#define IR_BUILDER_INSTANTIATE(T) \ - template T* IrBuilder::clone(const T* src, IrCloner* ir_cloner); - -// Vals -IR_BUILDER_INSTANTIATE(IterDomain) -IR_BUILDER_INSTANTIATE(TensorDomain) -IR_BUILDER_INSTANTIATE(TensorView) -IR_BUILDER_INSTANTIATE(Bool) -IR_BUILDER_INSTANTIATE(Float) -IR_BUILDER_INSTANTIATE(Double) -IR_BUILDER_INSTANTIATE(Int) -IR_BUILDER_INSTANTIATE(ComplexDouble) -IR_BUILDER_INSTANTIATE(NamedScalar) - -// Exprs -IR_BUILDER_INSTANTIATE(Split) -IR_BUILDER_INSTANTIATE(Merge) -IR_BUILDER_INSTANTIATE(Swizzle2D) -IR_BUILDER_INSTANTIATE(TransposeOp) -IR_BUILDER_INSTANTIATE(ExpandOp) -IR_BUILDER_INSTANTIATE(ShiftOp) -IR_BUILDER_INSTANTIATE(GatherOp) -IR_BUILDER_INSTANTIATE(ViewAsScalar) -IR_BUILDER_INSTANTIATE(ViewOp) -IR_BUILDER_INSTANTIATE(FullOp) -IR_BUILDER_INSTANTIATE(ARangeOp) -IR_BUILDER_INSTANTIATE(EyeOp) -IR_BUILDER_INSTANTIATE(UnaryOp) -IR_BUILDER_INSTANTIATE(BinaryOp) -IR_BUILDER_INSTANTIATE(TernaryOp) -IR_BUILDER_INSTANTIATE(SelectOp) -IR_BUILDER_INSTANTIATE(RNGOp) -IR_BUILDER_INSTANTIATE(ReductionOp) -IR_BUILDER_INSTANTIATE(GroupedReductionOp) -IR_BUILDER_INSTANTIATE(WelfordOp) -IR_BUILDER_INSTANTIATE(LoadStoreOp) -IR_BUILDER_INSTANTIATE(MmaOp) -IR_BUILDER_INSTANTIATE(BroadcastOp) -IR_BUILDER_INSTANTIATE(SqueezeOp) - Val* IrBuilder::newResult(DataType dtype) { switch (dtype) { case DataType::Bool: diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.h b/torch/csrc/jit/codegen/cuda/ir_builder.h index 77693fb5a26b..7791debbf52a 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.h +++ b/torch/csrc/jit/codegen/cuda/ir_builder.h @@ -111,6 +111,9 @@ class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder { static Val* minExpr(Val* lhs, Val* rhs); }; +template +NVFUSER_DEFINE_CLONE(FloatingPoint
) + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 93caef14b5cb..7bdbfc3774e1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -21,12 +21,7 @@ Statement* IrCloner::clone(const Statement* statement) { if (it != clones_map_.end()) { return it->second; } else { - // Clone the new node, saving/restoring this->clone_ - // since the cloning can be reentrant - auto saved_clone = clone_; - handle(statement); - auto new_node = clone_; - clone_ = saved_clone; + auto new_node = handle(statement); // The base cloning constructor (Statement) should have // registered the new node. Failure to do so indicates @@ -44,148 +39,8 @@ void IrCloner::registerClone(const Statement* src, Statement* clone) { TORCH_CHECK(clones_map_.insert({src, clone}).second); } -void IrCloner::handle(const Statement* s) { - OptInConstDispatch::handle(s); -} - -void IrCloner::handle(const Val* v) { - OptInConstDispatch::handle(v); -} - -void IrCloner::handle(const Expr* e) { - OptInConstDispatch::handle(e); -} - -void IrCloner::handle(const TensorDomain* td) { - clone_ = IrBuilder::clone(td, this); -} - -void IrCloner::handle(const IterDomain* id) { - clone_ = IrBuilder::clone(id, this); -} - -void IrCloner::handle(const Bool* b) { - clone_ = IrBuilder::clone(b, this); -} - -void IrCloner::handle(const Float* f) { - clone_ = IrBuilder::clone(f, this); -} - -void IrCloner::handle(const Double* d) { - clone_ = IrBuilder::clone(d, this); -} - -void IrCloner::handle(const Int* i) { - clone_ = IrBuilder::clone(i, this); -} - -void IrCloner::handle(const ComplexDouble* c) { - clone_ = IrBuilder::clone(c, this); -} - -void IrCloner::handle(const NamedScalar* named_scalar) { - clone_ = IrBuilder::clone(named_scalar, this); -} - -void IrCloner::handle(const TensorView* tv) { - clone_ = IrBuilder::clone(tv, this); -} - -void IrCloner::handle(const FullOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ARangeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const EyeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const UnaryOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const BinaryOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const TernaryOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const SelectOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const RNGOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const BroadcastOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const SqueezeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ReductionOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const GroupedReductionOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const WelfordOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const LoadStoreOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const MmaOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const TransposeOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ExpandOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ShiftOp* op) { - clone_ 
= IrBuilder::clone(op, this); -} - -void IrCloner::handle(const GatherOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ViewAsScalar* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const ViewOp* op) { - clone_ = IrBuilder::clone(op, this); -} - -void IrCloner::handle(const Split* split) { - clone_ = IrBuilder::clone(split, this); -} - -void IrCloner::handle(const Merge* merge) { - clone_ = IrBuilder::clone(merge, this); -} - -void IrCloner::handle(const Swizzle2D* swizzle) { - clone_ = IrBuilder::clone(swizzle, this); +Statement* IrCloner::handle(const Statement* s) { + return s->clone(this); } TensorView* RecomputeTv::recompute(TensorView* tv) { @@ -232,11 +87,18 @@ RecomputeTv::RecomputeTv(Fusion* fusion, std::vector exprs) } // Clone the expressions for (auto expr : exprs) { - IrCloner::handle(expr); + handle(expr); + } +} + +Statement* RecomputeTv::handle(const Statement* s) { + if (s->isA()) { + return handle(s->as()); } + return s->clone(this); } -void RecomputeTv::handle(const TensorDomain* td) { +Statement* RecomputeTv::handle(const TensorDomain* td) { // Make sure to recompute the history of the iteration domains, explicitly go // through the expressions and send them to IrCloner. auto exprs = @@ -245,7 +107,7 @@ void RecomputeTv::handle(const TensorDomain* td) { for (auto expr : exprs) { IrCloner::handle(expr); } - IrCloner::handle(td); + return IrCloner::handle(td); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index 790e44f9a108..a2b6e1cf7c0d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -20,13 +20,14 @@ class IrContainer; //! Fusion copy operations and the and limited scope of RecomputeTv below. //! It is not intended for any other uses. //! 
-class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { +class TORCH_CUDA_CU_API IrCloner { friend class Statement; friend class IrBuilder; public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit IrCloner(IrContainer* container); + virtual ~IrCloner() {} Statement* clone(const Statement* statement); @@ -53,47 +54,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { protected: void registerClone(const Statement* src, Statement* clone); - - void handle(const Statement*) override; - void handle(const Val*) override; - void handle(const Expr*) override; - - void handle(const TensorDomain*) override; - void handle(const TensorView*) override; - void handle(const IterDomain*) override; - - void handle(const Bool*) override; - void handle(const Float*) override; - void handle(const Double*) override; - void handle(const Int*) override; - void handle(const ComplexDouble*) override; - void handle(const NamedScalar*) override; - - void handle(const FullOp*) override; - void handle(const ARangeOp*) override; - void handle(const EyeOp*) override; - void handle(const UnaryOp*) override; - void handle(const BinaryOp*) override; - void handle(const TernaryOp*) override; - void handle(const SelectOp*) override; - void handle(const RNGOp*) override; - void handle(const BroadcastOp*) override; - void handle(const SqueezeOp*) override; - void handle(const ReductionOp*) override; - void handle(const GroupedReductionOp*) override; - void handle(const WelfordOp*) override; - void handle(const LoadStoreOp*) override; - void handle(const MmaOp*) override; - void handle(const TransposeOp*) override; - void handle(const ExpandOp*) override; - void handle(const ShiftOp*) override; - void handle(const GatherOp*) override; - void handle(const ViewAsScalar*) override; - void handle(const ViewOp*) override; - - void handle(const Split*) override; - void handle(const Merge*) override; - void handle(const Swizzle2D*) override; + virtual Statement* handle(const Statement* s); protected: // We keep track of the original -> clone map so we don't @@ -104,11 +65,6 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { // The destination Fusion container IrContainer* ir_container_ = nullptr; - // The dispatch interface doesn't allow returning values from - // individual `handle()` methods, so they are storing the - // result here - Statement* clone_ = nullptr; - // Builder to make all the new nodes IrBuilder builder_; }; @@ -123,12 +79,44 @@ class RecomputeTv : private IrCloner { private: RecomputeTv(Fusion* fusion, std::vector exprs); - - void handle(const TensorDomain*) final; + virtual Statement* handle(const Statement* s) override; + Statement* handle(const TensorDomain*); Fusion* fusion_; }; +//! Clone an IR node, forwarding the arguments to the IrCloner constructor. +template +T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) { + TORCH_INTERNAL_ASSERT( + ir_cloner != nullptr, + "Cannot use create when a cloner object is set. 
Use clone."); + + TORCH_INTERNAL_ASSERT( + ir_cloner->container() != nullptr, + "Cloner doesn't have a valid container to store cloned object."); + + T* dest = new T(src, ir_cloner); + const Statement* src_stmt = dynamic_cast(src); + Statement* dest_stmt = dynamic_cast(dest); + + auto dest_container = ir_cloner->container(); + auto src_container = src_stmt->container(); + + dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt); + + if (src_container != dest_container) { + dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name()); + } + + ir_cloner->registerClone(src_stmt, dest_stmt); + + return dest; +} + +template +NVFUSER_DEFINE_CLONE(Attribute) + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 12138ac6fb07..7e59416f434c 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -38,6 +38,8 @@ class TORCH_CUDA_CU_API Bool : public Val { Bool(const Bool* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -76,6 +78,8 @@ class TORCH_CUDA_CU_API FloatingPoint : public Val { FloatingPoint(const FloatingPoint* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} + NVFUSER_DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -120,6 +124,8 @@ class TORCH_CUDA_CU_API Int : public Val { Int(const Int* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -153,6 +159,8 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE + bool isSymbolic() const { return !(maybe_value_.has_value()); } @@ -230,6 +238,8 @@ class TORCH_CUDA_CU_API TensorView : public Val { TensorView(const TensorView* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE + TensorDomain* domain() const { return domain_; } diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 1c9994b7d337..cc70c5814239 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -32,33 +32,29 @@ bool areEqualScalars(Val* v1, Val* v2); class TORCH_CUDA_CU_API FullOp : public Expr { public: - FullOp(IrBuilderPasskey, Val* out, Val* fill_value, DataType dtype); + using Expr::Expr; - FullOp(const FullOp* src, IrCloner* ir_cloner); + FullOp(IrBuilderPasskey, Val* out, Val* fill_value, DataType dtype); - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "FullOp"; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value; } Val* getFillValue() const { - return fill_value_; + return inputs().back(); } - - private: - const DataType dtype_; - Val* fill_value_; }; class TORCH_CUDA_CU_API SelectOp : public Expr { public: + using Expr::Expr; + SelectOp( IrBuilderPasskey, Val* out, @@ -66,30 +62,25 @@ class TORCH_CUDA_CU_API SelectOp : public Expr { IterDomain* select_id, Val* index); - SelectOp(const SelectOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* 
getOpString() const override { return "SelectOp"; } - std::unordered_map getIndexOverridingMap() const { - return {{select_id_, input(1)}}; - } - IterDomain* getSelectAxis() const { - return select_id_; + return attribute(0)->as(); } - private: - IterDomain* select_id_; + std::unordered_map getIndexOverridingMap() const { + return {{getSelectAxis(), input(1)}}; + } }; class TORCH_CUDA_CU_API ARangeOp : public Expr { public: + using Expr::Expr; + ARangeOp( IrBuilderPasskey, Val* out, @@ -99,46 +90,31 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { DataType dtype, Val* linear_index = nullptr); - ARangeOp(const ARangeOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ARangeOp"; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value; } Val* start() const { - return start_; + return input(0); } Val* end() const { - return end_; + return input(1); } Val* step() const { - return step_; + return input(2); } Val* getLinearLogicalIndex() const { - return linear_index_; - } - - void setLinearIndex(Val* index) { - linear_index_ = index; + return attributeVal(1); } - - private: - const DataType dtype_; - Val* start_; - Val* end_; - Val* step_; - Val* linear_index_ = nullptr; }; // Tensor factory for generating identity matrices like @@ -161,6 +137,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { // [0, 0, 1, 0]] class TORCH_CUDA_CU_API EyeOp : public Expr { public: + using Expr::Expr; + EyeOp( IrBuilderPasskey, Val* out, @@ -168,40 +146,23 @@ class TORCH_CUDA_CU_API EyeOp : public Expr { Val* index1 = nullptr, Val* index2 = nullptr); - EyeOp(const EyeOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; - - bool sameAs(const Statement* other) const override; + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "EyeOp"; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value; } Val* getIndex1() const { - return index1_; - } - - void setIndex1(Val* index) { - index1_ = index; + return attributeVal(1); } Val* getIndex2() const { - return index2_; + return attributeVal(2); } - - void setIndex2(Val* index) { - index2_ = index; - } - - private: - const DataType dtype_; - Val* index1_ = nullptr; - Val* index2_ = nullptr; }; //! A specialization for Unary operations. Unary operations take in a single @@ -212,38 +173,26 @@ class TORCH_CUDA_CU_API EyeOp : public Expr { //! 4) split/merge class TORCH_CUDA_CU_API UnaryOp : public Expr { public: - UnaryOp( - IrBuilderPasskey, - UnaryOpType type, - Val* out, - Val* in, - int rng_offset = -1); + using Expr::Expr; - UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); + UnaryOp(IrBuilderPasskey, UnaryOpType type, Val* out, Val* in); - Expr* shallowCopy() const override; + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "UnaryOp"; } Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } UnaryOpType getUnaryOpType() const { - return unary_op_type_; + return attribute(0)->as>()->value; } - - bool sameAs(const Statement* other) const override; - - private: - const UnaryOpType unary_op_type_; - Val* const out_ = nullptr; - Val* const in_ = nullptr; }; //! A specialization for Binary operations. 
Binary operations take in two inputs @@ -252,43 +201,89 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { //! 2) LT (A < B) class TORCH_CUDA_CU_API BinaryOp : public Expr { public: - BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs); + using Expr::Expr; - BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); + BinaryOp(IrBuilderPasskey, BinaryOpType type, Val* out, Val* lhs, Val* rhs); - Expr* shallowCopy() const override; + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "BinaryOp"; } Val* out() const { - return out_; + return output(0); } Val* lhs() const { - return lhs_; + return input(0); } Val* rhs() const { - return rhs_; + return input(1); } BinaryOpType getBinaryOpType() const { - return binary_op_type_; + return attribute(0)->as>()->value; } +}; - bool sameAs(const Statement* other) const override; +class TORCH_CUDA_CU_API TernaryOp : public Expr { + public: + using Expr::Expr; - private: - const BinaryOpType binary_op_type_; - Val* const out_ = nullptr; - Val* const lhs_ = nullptr; - Val* const rhs_ = nullptr; + TernaryOp( + IrBuilderPasskey, + TernaryOpType type, + Val* out, + Val* in1, + Val* in2, + Val* in3); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "TernaryOp"; + } + + Val* out() const { + return output(0); + } + + Val* in1() const { + return input(0); + } + Val* in2() const { + return input(1); + } + Val* in3() const { + return input(2); + } + + TernaryOpType getTernaryOpType() const { + return attribute(0)->as>()->value; + } }; //! A specialization for random number generator (RNG) operations. RNG //! operations take in no tensor input and produce a single output. class TORCH_CUDA_CU_API RNGOp : public Expr { + size_t getOutputDims() const; + public: + struct Attributes { + RNGOpType rtype; + DataType dtype; + int rng_offset; + + // TODO: Enable the following in C++20: + // bool operator==(const Attributes &other) const = default; + bool operator==(const Attributes& other) const { + return rtype == other.rtype && dtype == other.dtype && + rng_offset == other.rng_offset; + } + }; + + using Expr::Expr; + RNGOp( IrBuilderPasskey, RNGOpType type, @@ -298,66 +293,51 @@ class TORCH_CUDA_CU_API RNGOp : public Expr { int rng_offset = 0, Val* philox_index = nullptr); - RNGOp(const RNGOp* src, IrCloner* ir_cloner); - - Expr* shallowCopy() const override; + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "RNGOp"; } RNGOpType getRNGOpType() const { - return rng_op_type_; + return attribute(0)->as>()->value.rtype; } DataType dtype() const { - return dtype_; + return attribute(0)->as>()->value.dtype; } int getRNGOffset() const { - return rng_offset_; + return attribute(0)->as>()->value.rng_offset; } void setRNGOffset(int val) { - rng_offset_ = val; + attribute(0)->as>()->value.rng_offset = val; } - const std::vector& getParameters() const { - return parameters_; + std::vector getParameters() const { + return {inputs().begin() + getOutputDims(), inputs().end()}; } - const std::vector& getShape() const { - return shape_; + std::vector getShape() const { + return {inputs().begin(), inputs().begin() + getOutputDims()}; } Val* getPhiloxIndex() const { - return philox_index_; - } - - void setPhiloxIndex(Val* index) { - philox_index_ = index; + return attributeVal(1); } int getPhiloxMultiple() const { - return dtype_ == DataType::Double ? 2 : 4; + return dtype() == DataType::Double ? 
2 : 4; } - - bool sameAs(const Statement* other) const override; - - private: - const RNGOpType rng_op_type_; - const DataType dtype_; - std::vector parameters_; - std::vector shape_; - int rng_offset_ = -1; - // The index used to feed philox's subsequence and component - Val* philox_index_ = nullptr; }; //! Broadcast in to match out. is_broadcast_dims are relative to out. Where //! is_broadcast_dims.size() == out->nDims(). class TORCH_CUDA_CU_API BroadcastOp : public Expr { public: + using Expr::Expr; + //! \param out The output tensor //! \param in The input tensor //! \param is_broadcast_dims True when output dim is a new broadcast domain @@ -367,42 +347,32 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { Val* in, std::vector is_broadcast_dims); - BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "BroadcastOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } bool isBroadcastDim(size_t dim) const { - return is_broadcast_dims_.at(dim); - } - - const std::vector& getBroadcastDimFlags() const { - return is_broadcast_dims_; + return getBroadcastDimFlags().at(dim); } - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! The same list passed to the broadcast arithmetic op. Each //! element corresponds to an IterDomain of the output tensor and is //! true when the IterDomain is a new broadcast domain. Note //! that the output tensor may have other broadcast domains whose //! flags are false because the input tensor may already have //! broadcast domains. - const std::vector is_broadcast_dims_; + const std::vector& getBroadcastDimFlags() const { + return attribute(0)->as>>()->value; + } }; //! Squeeze in to match out. is_squeeze_dims are relative to in. Where @@ -410,6 +380,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { //! broadcast. class TORCH_CUDA_CU_API SqueezeOp : public Expr { public: + using Expr::Expr; + //! \param out The output tensor //! \param in The input tensor //! \param is_squeeze_dims True when input dim is a removed broadcast domain @@ -419,42 +391,32 @@ class TORCH_CUDA_CU_API SqueezeOp : public Expr { Val* in, std::vector is_broadcast_dims); - SqueezeOp(const SqueezeOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "SqueezeOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } bool isSqueezeDim(size_t dim) const { - return is_squeeze_dims_.at(dim); + return getSqueezeDimFlags().at(dim); } - const std::vector& getSqueezeDimFlags() const { - return is_squeeze_dims_; - } - - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! The same list passed to the squeeze arithmetic op. Each //! element corresponds to an IterDomain of the input tensor and is //! true when the IterDomain is a broadcast domain that is removed in the //! output. Note that the output tensor may still contain broadcast domains //! because the input tensor may have broadcast domains that we don't want to //! remove (false flag). - const std::vector is_squeeze_dims_; + const std::vector& getSqueezeDimFlags() const { + return attribute(0)->as>>()->value; + } }; //! 
Reduction operation. Out is first initialized to _init. Then @@ -464,6 +426,8 @@ class TORCH_CUDA_CU_API SqueezeOp : public Expr { //! non-reduction/non-broadcast dimensions. class TORCH_CUDA_CU_API ReductionOp : public Expr { public: + using Expr::Expr; + ReductionOp( IrBuilderPasskey, BinaryOpType reduction_op_type, @@ -472,41 +436,29 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { Val* in, bool is_allreduce = false); - ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ReductionOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } Val* init() const { - return init_; + return attributeVal(0); } BinaryOpType getReductionOpType() const { - return reduction_op_type_; + return attribute(1)->as>()->value; } bool isAllreduce() const { - return is_allreduce_; + return attribute(2)->as>()->value; } - - bool sameAs(const Statement* other) const override; - - private: - const BinaryOpType reduction_op_type_; - Val* const init_ = nullptr; - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! True if broadcast is fused - bool is_allreduce_ = false; }; //! Grouped reduction operation for horizontal fusions. It works like @@ -517,61 +469,57 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { //! significant performance impact. class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { public: + using Expr::Expr; + GroupedReductionOp( IrBuilderPasskey, - std::vector reduction_op_type, + std::vector reduction_op_types, std::vector init, std::vector out, std::vector in, bool is_allreduce = false); - GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "GroupedReductionOp"; } - Expr* shallowCopy() const override; - //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. size_t numExprs() const { - return reduction_op_types_.size(); + return getReductionOpTypes().size(); } - const std::vector& initVals() const { - return init_vals_; + std::vector initVals() const { + auto size = numExprs(); + std::vector result; + result.reserve(size); + for (auto i : c10::irange(2, 2 + size)) { + result.emplace_back(attribute(i)->as()); + } + return result; } Val* initVal(size_t index) const { - return init_vals_.at(index); + return attributeVal(2 + index); } const std::vector& getReductionOpTypes() const { - return reduction_op_types_; + return attribute(0)->as>>()->value; } BinaryOpType getReductionOpType(size_t index) const { - return reduction_op_types_.at(index); + return getReductionOpTypes().at(index); } bool isAllreduce() const { - return is_allreduce_; + return attribute(1)->as>()->value; } //! Return the index of the corresponding reduction expression for //! a given output val. int getExprIndexOfOutput(Val* output_val) const; - - bool sameAs(const Statement* other) const override; - - private: - //! Reduction ops of grouped reductions - const std::vector reduction_op_types_; - //! Initial values of grouped reductions - const std::vector init_vals_; - //! True if using the fused reduction kernel - bool is_allreduce_ = false; }; //! Average, variance and N (count) vals for Welford @@ -697,6 +645,8 @@ class TORCH_CUDA_CU_API WelfordTriplet { //! Welford Scan operation. 
class TORCH_CUDA_CU_API WelfordOp : public Expr { public: + using Expr::Expr; + WelfordOp( IrBuilderPasskey, const WelfordTriplet& output, @@ -717,70 +667,66 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { Val* init_N, bool is_fused = false); - WelfordOp(const WelfordOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "WelfordOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return output().avg(); + return outputTriplet().avg(); } Val* in() const { - return input().avg(); + return inputTriplet().avg(); } - bool sameAs(const Statement* const other) const override; - - const WelfordTriplet& output() const { - return output_; + WelfordTriplet outputTriplet() const { + return WelfordTriplet(outAvg(), outVar(), outN()); } Val* outAvg() const { - return output().avg(); + return output(0); } Val* outVar() const { - return output().var(); + return output(1); } Val* outN() const { - return output().N(); + return output(2); } - const WelfordTriplet& input() const { - return input_; + WelfordTriplet inputTriplet() const { + return WelfordTriplet(inAvg(), inVar(), inN()); } Val* inAvg() const { - return input().avg(); + return input(0); } Val* inVar() const { - return input().var(); + return input(1); } Val* inN() const { - return input().N(); + return input(2); } - const WelfordTriplet& init() const { - return init_; + WelfordTriplet initTriplet() const { + return WelfordTriplet(initAvg(), initVar(), initN()); } Val* initAvg() const { - return init().avg(); + return attributeVal(0); } Val* initVar() const { - return init().var(); + return attributeVal(1); } Val* initN() const { - return init().N(); + return attributeVal(2); } bool singleValue() const { @@ -791,25 +737,21 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { return !initN()->isZeroInt(); } + //! True if using the fused reduction kernel (not implemented yet) bool isAllreduce() const { - return is_allreduce_; + return attribute(3)->as>()->value; } std::vector getInitVals() const; //! Return the init val for an output val Val* getInitValOfOutput(Val* output_val) const; - - private: - const WelfordTriplet output_; - const WelfordTriplet input_; - const WelfordTriplet init_; - //! True if using the fused reduction kernel (not implemented yet) - bool is_allreduce_ = false; }; class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { public: + using Expr::Expr; + GroupedWelfordOp( IrBuilderPasskey, std::vector output_vals, @@ -817,14 +759,12 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { std::vector init_vals, bool is_allreduce = false); - GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "GroupedWelfordOp"; } - Expr* shallowCopy() const override; - //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. As horizontal grouping is not supported, //! this always returns 1. 
@@ -840,54 +780,70 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { return inAvg(index); } - bool sameAs(const Statement* const other) const override; - - const std::vector& outputVals() const { - return output_vals_; + std::vector outputVals() const { + std::vector result; + auto size = outputs().size() / 3; + result.reserve(size); + for (auto i : c10::irange(size)) { + result.emplace_back(outAvg(i), outVar(i), outN(i)); + } + return result; } - const std::vector& inputVals() const { - return input_vals_; + std::vector inputVals() const { + std::vector result; + auto size = inputs().size() / 3; + result.reserve(size); + for (auto i : c10::irange(size)) { + result.emplace_back(inAvg(i), inVar(i), inN(i)); + } + return result; } - const std::vector& initVals() const { - return init_vals_; + std::vector initVals() const { + std::vector result; + auto size = inputs().size() / 3; + result.reserve(size); + for (auto i : c10::irange(size)) { + result.emplace_back(initAvg(i), initVar(i), initN(i)); + } + return result; } Val* outAvg(size_t index) const { - return outputVals().at(index).avg(); + return output(index * 3); } Val* outVar(size_t index) const { - return outputVals().at(index).var(); + return output(index * 3 + 1); } Val* outN(size_t index) const { - return outputVals().at(index).N(); + return output(index * 3 + 2); } Val* inAvg(size_t index) const { - return inputVals().at(index).avg(); + return input(index * 3); } Val* inVar(size_t index) const { - return inputVals().at(index).var(); + return input(index * 3 + 1); } Val* inN(size_t index) const { - return inputVals().at(index).N(); + return input(index * 3 + 2); } Val* initAvg(size_t index) const { - return initVals().at(index).avg(); + return attributeVal(1 + index * 3); } Val* initVar(size_t index) const { - return initVals().at(index).var(); + return attributeVal(2 + index * 3); } Val* initN(size_t index) const { - return initVals().at(index).N(); + return attributeVal(3 + index * 3); } //! Return the index of the corresponding welford expression for @@ -906,15 +862,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { } bool isAllreduce() const { - return is_allreduce_; + return attribute(0)->as>()->value; } - - private: - const std::vector output_vals_; - const std::vector input_vals_; - const std::vector init_vals_; - //! True if using the fused reduction kernel - bool is_allreduce_ = false; }; //! 
Fused Matmul operation @@ -936,6 +885,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { } }; + using Expr::Expr; + MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init); MmaOp( @@ -946,181 +897,104 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { Val* init, OptionsInMma options); - MmaOp(const MmaOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "MmaOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* inA() const { - return in_a_; + return input(0); } Val* inB() const { - return in_b_; + return input(1); } Val* init() const { - return init_; + return attributeVal(0); } const auto& options() const { - TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this); - return options_.value(); + return attribute(1)->as>()->value; } - bool sameAs(const Statement* const other) const override; - auto accStride() const { - TORCH_INTERNAL_ASSERT(options_.has_value(), "MmaOp not configured:", this); - return options_->accumulator_stride; + return options().accumulator_stride; } - void configureOptions(MmaOptions options) { - options_ = OptionsInMma(); - TORCH_INTERNAL_ASSERT( - options.macro != MmaOptions::MacroType::NoMMA, - "Un-configured mma type from options."); - TORCH_INTERNAL_ASSERT( - options.accumulator_stride > 0, "Un-configured accumulator stride."); - options_->accumulator_stride = options.accumulator_stride; - options_->macro = options.macro; - options_->operand_layout = options.operand_layout; - } - - private: - Val* const out_ = nullptr; - Val* const in_a_ = nullptr; - Val* const in_b_ = nullptr; - Val* const init_ = nullptr; - c10::optional options_ = c10::nullopt; + void configureOptions(MmaOptions options); }; class TORCH_CUDA_CU_API TransposeOp : public Expr { public: + using Expr::Expr; + TransposeOp( IrBuilderPasskey, TensorView* out, TensorView* in, std::vector new2old); - TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "TransposeOp"; } - Expr* shallowCopy() const override; - TensorView* out() const { - return out_; + return output(0)->as(); } TensorView* in() const { - return in_; + return input(0)->as(); } const std::vector& new2old() const { - return new2old_; + return attribute(0)->as>>()->value; } std::vector old2new() const; - - private: - TensorView* const out_ = nullptr; - TensorView* const in_ = nullptr; - const std::vector new2old_; }; class TORCH_CUDA_CU_API ExpandOp : public Expr { public: + using Expr::Expr; + ExpandOp( IrBuilderPasskey, TensorView* out, TensorView* in, std::vector _expanded_extents); - ExpandOp(const ExpandOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ExpandOp"; } - Expr* shallowCopy() const override; - TensorView* out() const { - return out_; + return output(0)->as(); } TensorView* in() const { - return in_; - } - - const std::vector& expanded_extents() const { - return expanded_extents_; - } - - private: - TensorView* const out_ = nullptr; - TensorView* const in_ = nullptr; - std::vector expanded_extents_; -}; - -class TORCH_CUDA_CU_API TernaryOp : public Expr { - public: - TernaryOp( - IrBuilderPasskey, - TernaryOpType type, - Val* out, - Val* in1, - Val* in2, - Val* in3); - - TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); - - virtual const char* getOpString() const override { - return "TernaryOp"; - } - - 
Expr* shallowCopy() const override; - - Val* out() const { - return out_; + return input(0)->as(); } - Val* in1() const { - return in1_; + std::vector expanded_extents() const { + return {inputs().begin() + 1, inputs().end()}; } - Val* in2() const { - return in2_; - } - Val* in3() const { - return in3_; - } - - TernaryOpType getTernaryOpType() const { - return ternary_op_type_; - } - - bool sameAs(const Statement* other) const override; - - private: - const TernaryOpType ternary_op_type_; - Val* const out_ = nullptr; - Val* const in1_ = nullptr; - Val* const in2_ = nullptr; - Val* const in3_ = nullptr; }; //! Shift class TORCH_CUDA_CU_API ShiftOp : public Expr { public: + using Expr::Expr; + //! \param out //! \param in //! \param offsets @@ -1131,54 +1005,45 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { std::vector offsets, std::vector pad_width); - ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ShiftOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } int offset(size_t dim) const { - return offsets_.at(dim); + return offsets().at(dim); } + //! Each of the root axes is shifted by the corresponding value of + //! offsets. The sign of each value indicates the direction of shifting. const std::vector& offsets() const { - return offsets_; + return attribute(0)->as>>()->value; } const std::vector& padWidth() const { - return pad_width_; + return attribute(1)->as>>()->value; } bool hasPadding() const { - return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto p) { + return std::any_of(padWidth().begin(), padWidth().end(), [](const auto p) { return p > 0; }); } - - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! Each of the root axes is shifted by the corresponding value of - //! offsets_. The sign of each value indicates the direction of - //! shifting. - const std::vector offsets_; - const std::vector pad_width_; }; //! Gather a window around each element. class TORCH_CUDA_CU_API GatherOp : public Expr { public: + using Expr::Expr; + GatherOp( IrBuilderPasskey, Val* out, @@ -1186,51 +1051,43 @@ class TORCH_CUDA_CU_API GatherOp : public Expr { std::vector window_shape, std::vector> pad_width); - GatherOp(const GatherOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "GatherOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } + //! Shape of a window gathered for each element. const auto& windowShape() const { - return window_shape_; + return attribute(0)->as>>()->value; } //! Returns the gather axis that corresponds to an input axis int gatherAxis(int axis) const; + //! The size of zero-padding of each axis. const auto& padWidth() const { - return pad_width_; + return attribute(1)->as>>>()->value; } bool hasPadding() const { - return std::any_of(pad_width_.begin(), pad_width_.end(), [](const auto& p) { + return std::any_of(padWidth().begin(), padWidth().end(), [](const auto& p) { return p[0] > 0 || p[1] > 0; }); } - - bool sameAs(const Statement* other) const override; - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - //! Shape of a window gathered for each element. - std::vector window_shape_; - //! 
The size of zero-padding of each axis. - std::vector> pad_width_; }; class TORCH_CUDA_CU_API ViewAsScalar : public Expr { public: + using Expr::Expr; + ViewAsScalar( IrBuilderPasskey, Val* out, @@ -1238,64 +1095,50 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr { IterDomain* vector_id, Val* index = nullptr); - ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ViewAsScalar"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } + // The IterDomain of type VectorComponent newly appended to the output IterDomain* vector_id() const { - return vector_id_; + return attribute(0)->as(); } + // The index that vector_id_ is lowered into Val* index() const { - return index_; + return attributeVal(1); } - - private: - Val* const out_ = nullptr; - Val* const in_ = nullptr; - - // The IterDomain of type VectorComponent newly appended to the output - IterDomain* vector_id_ = nullptr; - - // The index that vector_id_ is lowered into - Val* index_ = nullptr; }; class TORCH_CUDA_CU_API ViewOp : public Expr { public: - ViewOp(IrBuilderPasskey, TensorView* out, TensorView* in); + using Expr::Expr; - ViewOp(const ViewOp* src, IrCloner* ir_cloner); + ViewOp(IrBuilderPasskey, Val* out, Val* in); + + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "ViewOp"; } - Expr* shallowCopy() const override; - - TensorView* out() const { - return out_; + Val* out() const { + return output(0); } - TensorView* in() const { - return in_; + Val* in() const { + return input(0); } - - private: - TensorView* const out_ = nullptr; - TensorView* const in_ = nullptr; }; //! This operator explicitly models data movement between @@ -1306,32 +1149,27 @@ class TORCH_CUDA_CU_API ViewOp : public Expr { //! accelerated memory ops, i.e. ldmatrix, cp.async and more to come. class TORCH_CUDA_CU_API LoadStoreOp : public Expr { public: + using Expr::Expr; + LoadStoreOp(IrBuilderPasskey, LoadStoreOpType op_type, Val* out, Val* in); - LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "LoadStoreOp"; } - Expr* shallowCopy() const override; - Val* out() const { - return out_; + return output(0); } Val* in() const { - return in_; + return input(0); } LoadStoreOpType opType() const { - return load_store_type_; + return attribute(0)->as>()->value; } - - private: - LoadStoreOpType load_store_type_ = LoadStoreOpType::LdMatrix; - Val* const out_ = nullptr; - Val* const in_ = nullptr; }; // Convenience utility to initialize IterDomain's without having to sort through @@ -1411,6 +1249,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val { IterDomain(const IterDomain* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE + bool sameAs(const Statement* other) const override; //! Returns a new IterDomain matching properties of this @@ -1718,6 +1558,8 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { TensorDomain(const TensorDomain* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE + bool operator==(const TensorDomain& other) const; bool operator!=(const TensorDomain& other) const { return !(*this == other); @@ -1906,6 +1748,8 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { //! remainer or outside. 
class TORCH_CUDA_CU_API Split : public Expr { public: + using Expr::Expr; + // start_offset and stop_offset are used to express partial // split. Only the partial domain from start_offset to stop_offset // is split and the outer sub-regions are ignored. Note that both @@ -1921,58 +1765,45 @@ class TORCH_CUDA_CU_API Split : public Expr { Val* start_offset = nullptr, Val* stop_offset = nullptr); - Split(const Split* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "Split"; } - Expr* shallowCopy() const override; - IterDomain* outer() const { - return outer_; + return output(0)->as(); } IterDomain* inner() const { - return inner_; + return output(1)->as(); } IterDomain* in() const { - return in_; + return input(0)->as(); } Val* factor() const { - return factor_; + return attributeVal(0); } bool innerSplit() const { - return inner_split_; + return attribute(1)->as>()->value; } + //! Start position of the input domain. Non-zero means partial + //! split. Elements until this offset are ignored. Val* startOffset() const { - TORCH_INTERNAL_ASSERT(start_offset_ != nullptr); - return start_offset_; + TORCH_INTERNAL_ASSERT(attributeVal(2) != nullptr); + return attributeVal(2); } + //! Offset from extent of the input domain. Non-zero means partial + //! split. Elements after this offset are ignored. Val* stopOffset() const { - TORCH_INTERNAL_ASSERT(stop_offset_ != nullptr); - return stop_offset_; + TORCH_INTERNAL_ASSERT(attributeVal(3) != nullptr); + return attributeVal(3); } //! Utility function to compute the split extent. static Val* extent(Val* in_extent, Val* start_offset, Val* stop_offset); - - bool sameAs(const Statement* other) const override; - - private: - IterDomain* const outer_ = nullptr; - IterDomain* const inner_ = nullptr; - IterDomain* const in_ = nullptr; - Val* const factor_ = nullptr; - bool inner_split_ = true; - //! Start position of the input domain. Non-zero means partial - //! split. Elements until this offset are ignored. - Val* const start_offset_ = nullptr; - //! Offset from extent of the input domain. Non-zero means partial - //! split. Elements after this offset are ignored. - Val* const stop_offset_ = nullptr; }; //! Merge the IterDomains outer and inner into one domain, outer and inner @@ -1981,41 +1812,36 @@ class TORCH_CUDA_CU_API Split : public Expr { //! strategy if there is one class TORCH_CUDA_CU_API Merge : public Expr { public: + using Expr::Expr; + Merge( IrBuilderPasskey, IterDomain* out, IterDomain* outer, IterDomain* inner); - Merge(const Merge* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "Merge"; } - Expr* shallowCopy() const override; - IterDomain* out() const { - return out_; + return output(0)->as(); } IterDomain* outer() const { - return outer_; + return input(0)->as(); } IterDomain* inner() const { - return inner_; + return input(1)->as(); } - - bool sameAs(const Statement* other) const override; - - private: - IterDomain* const out_ = nullptr; - IterDomain* const outer_ = nullptr; - IterDomain* const inner_ = nullptr; }; //! Applies 2D swizzles on a rectangular tile defined by 2 iterdomains. 
class TORCH_CUDA_CU_API Swizzle2D : public Expr { public: + using Expr::Expr; + Swizzle2D( IrBuilderPasskey, IterDomain* out_x, @@ -2025,53 +1851,36 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle, SwizzleMode swizzle_mode = SwizzleMode::Data); - Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "Swizzle2D"; } - Expr* shallowCopy() const override; - + // Output iterdomain pair corresponding + // to the original input iterdomain pair. IterDomain* outX() const { - return out_x_; + return output(0)->as(); } IterDomain* outY() const { - return out_y_; + return output(1)->as(); } + // Input iterdomain pair. IterDomain* inX() const { - return in_x_; + return input(0)->as(); } IterDomain* inY() const { - return in_y_; - } - - auto swizzleType() const { - return swizzle_type_; + return input(1)->as(); } - auto swizzleMode() const { - return swizzle_mode_; - } - - bool sameAs(const Statement* other) const override; - - private: - // Output iterdomain pair corresponding - // to the original input iterdomain pair. - IterDomain* const out_x_ = nullptr; - IterDomain* const out_y_ = nullptr; - - // Input iterdomain pair. - IterDomain* const in_x_ = nullptr; - IterDomain* const in_y_ = nullptr; - // The type of predefined 1-to-1 functions // used for swizzling math. - Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle; + auto swizzleType() const { + return attribute(0)->as>()->value; + } // Swizzle mode of this swizzle instance. // [Note on swizzle mode] @@ -2114,7 +1923,9 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { // } // TODO: Loop swizzles eventually will be piped through in all mappings // and replay of the fusion IR infrastructure. - SwizzleMode swizzle_mode_ = SwizzleMode::Data; + auto swizzleMode() const { + return attribute(1)->as>()->value; + } }; //! 
Integer value which has a special name @@ -2131,6 +1942,8 @@ class TORCH_CUDA_CU_API NamedScalar : public Val { NamedScalar(const NamedScalar* src, IrCloner* ir_cloner); + NVFUSER_DECLARE_CLONE + const std::string& name() const { return name_; } diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index e2a809b297c9..09bec9009311 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -89,6 +89,8 @@ Bool::Bool(IrBuilderPasskey passkey, c10::optional value) Bool::Bool(const Bool* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} +NVFUSER_DEFINE_CLONE(Bool) + bool Bool::sameAs(const Statement* other) const { if (this == other) { return true; @@ -116,6 +118,8 @@ Int::Int(IrBuilderPasskey passkey, c10::optional value) Int::Int(const Int* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} +NVFUSER_DEFINE_CLONE(Int) + bool Int::sameAs(const Statement* other) const { if (this == other) { return true; @@ -147,6 +151,8 @@ ComplexDouble::ComplexDouble( ComplexDouble::ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner) : Val(src, ir_cloner), maybe_value_(src->maybe_value_) {} +NVFUSER_DEFINE_CLONE(ComplexDouble) + bool ComplexDouble::sameAs(const Statement* other) const { if (this == other) { return true; @@ -165,7 +171,7 @@ FullOp::FullOp( Val* out, Val* fill_value, DataType dtype) - : Expr(passkey), dtype_(dtype), fill_value_(fill_value) { + : Expr(passkey) { if (out->isA()) { auto tv_root = out->as()->getRootDomain(); for (auto id : tv_root) { @@ -174,32 +180,11 @@ FullOp::FullOp( } addInput(fill_value); addOutput(out); + addAttribute( + IrBuilder::create>(passkey.ir_container_, dtype)); } -FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - dtype_(src->dtype()), - fill_value_(ir_cloner->clone(src->fill_value_)) {} - -Expr* FullOp::shallowCopy() const { - auto result = IrBuilder::create(output(0), fill_value_, dtype_); - result->copyPredicatesFrom(this); - return result; -} - -bool FullOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (dtype_ != other_op->dtype_) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(FullOp) SelectOp::SelectOp( IrBuilderPasskey passkey, @@ -207,35 +192,15 @@ SelectOp::SelectOp( Val* in, IterDomain* select_id, Val* index) - : Expr(passkey), select_id_(select_id) { + : Expr(passkey) { addInput(in); addInput(index); addOutput(out); + addAttribute(select_id); + addAttribute(index); } -SelectOp::SelectOp(const SelectOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), select_id_(ir_cloner->clone(src->select_id_)) {} - -Expr* SelectOp::shallowCopy() const { - auto result = - IrBuilder::create(output(0), input(0), select_id_, input(1)); - result->copyPredicatesFrom(this); - return result; -} - -bool SelectOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (!select_id_->sameAs(other_op->select_id_)) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(SelectOp) ARangeOp::ARangeOp( IrBuilderPasskey passkey, @@ -245,62 +210,17 @@ ARangeOp::ARangeOp( Val* step, DataType dtype, Val* linear_index) - : Expr(passkey), - dtype_(dtype), - start_(start), - end_(end), - step_(step), 
- linear_index_(linear_index) { + : Expr(passkey) { addInput(start); addInput(end); addInput(step); addOutput(out); + addAttribute( + IrBuilder::create>(passkey.ir_container_, dtype)); + addAttribute(linear_index); } -ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - dtype_(src->dtype()), - start_(ir_cloner->clone(src->start_)), - end_(ir_cloner->clone(src->end_)), - step_(ir_cloner->clone(src->step_)), - linear_index_(ir_cloner->clone(src->linear_index_)) {} - -Expr* ARangeOp::shallowCopy() const { - auto result = IrBuilder::create( - output(0), start_, end_, step_, dtype_, linear_index_); - result->copyPredicatesFrom(this); - return result; -} - -bool ARangeOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (dtype_ != other_op->dtype_) { - return false; - } - if (!start_->sameAs(other_op->start_)) { - return false; - } - if (!end_->sameAs(other_op->end_)) { - return false; - } - if (!step_->sameAs(other_op->step_)) { - return false; - } - if ((linear_index_ == nullptr) != (other_op->linear_index_ == nullptr)) { - return false; - } - if ((linear_index_ != nullptr) && - !linear_index_->sameAs(other_op->linear_index_)) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(ARangeOp) EyeOp::EyeOp( IrBuilderPasskey passkey, @@ -308,7 +228,7 @@ EyeOp::EyeOp( DataType dtype, Val* index1, Val* index2) - : Expr(passkey), dtype_(dtype), index1_(index1), index2_(index2) { + : Expr(passkey) { if (out->isA()) { addInput(out->as()->getRootDomain()[0]->extent()); if (out->as()->getRootDomain()[1] != @@ -317,82 +237,23 @@ EyeOp::EyeOp( } } addOutput(out); + addAttribute( + IrBuilder::create>(passkey.ir_container_, dtype)); + addAttribute(index1); + addAttribute(index2); } -EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - dtype_(src->dtype_), - index1_(ir_cloner->clone(src->index1_)), - index2_(ir_cloner->clone(src->index2_)) {} +NVFUSER_DEFINE_CLONE_AND_CREATE(EyeOp) -Expr* EyeOp::shallowCopy() const { - auto result = IrBuilder::create(output(0), dtype_, index1_, index2_); - result->copyPredicatesFrom(this); - return result; -} - -bool EyeOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (dtype_ != other_op->dtype_) { - return false; - } - if ((index1_ == nullptr) != (other_op->index1_ == nullptr)) { - return false; - } - if ((index2_ == nullptr) != (other_op->index2_ == nullptr)) { - return false; - } - if ((index1_ != nullptr) && !index1_->sameAs(other_op->index1_)) { - return false; - } - if ((index2_ != nullptr) && !index2_->sameAs(other_op->index2_)) { - return false; - } - return Expr::sameAs(other); -} - -UnaryOp::UnaryOp( - IrBuilderPasskey passkey, - UnaryOpType type, - Val* out, - Val* in, - int rng_offset) - : Expr(passkey), unary_op_type_{type}, out_{out}, in_{in} { +UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in) + : Expr(passkey) { addOutput(out); addInput(in); + addAttribute( + IrBuilder::create>(passkey.ir_container_, type)); } -UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - unary_op_type_(src->unary_op_type_), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) {} - -Expr* UnaryOp::shallowCopy() const { - auto result = IrBuilder::create(unary_op_type_, out_, 
in_); - result->copyPredicatesFrom(this); - return result; -} - -bool UnaryOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getUnaryOpType() != other_op->getUnaryOpType()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(UnaryOp) BinaryOp::BinaryOp( IrBuilderPasskey passkey, @@ -400,38 +261,15 @@ BinaryOp::BinaryOp( Val* out, Val* lhs, Val* rhs) - : Expr(passkey), binary_op_type_{type}, out_{out}, lhs_{lhs}, rhs_{rhs} { + : Expr(passkey) { addOutput(out); addInput(lhs); addInput(rhs); + addAttribute( + IrBuilder::create>(passkey.ir_container_, type)); } -BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - binary_op_type_(src->binary_op_type_), - out_(ir_cloner->clone(src->out_)), - lhs_(ir_cloner->clone(src->lhs_)), - rhs_(ir_cloner->clone(src->rhs_)) {} - -Expr* BinaryOp::shallowCopy() const { - auto result = IrBuilder::create(binary_op_type_, out_, lhs_, rhs_); - result->copyPredicatesFrom(this); - return result; -} - -bool BinaryOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getBinaryOpType() != other_op->getBinaryOpType()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(BinaryOp) TernaryOp::TernaryOp( IrBuilderPasskey passkey, @@ -440,46 +278,16 @@ TernaryOp::TernaryOp( Val* in1, Val* in2, Val* in3) - : Expr(passkey), - ternary_op_type_{type}, - out_{out}, - in1_{in1}, - in2_{in2}, - in3_{in3} { + : Expr(passkey) { addOutput(out); addInput(in1); addInput(in2); addInput(in3); + addAttribute( + IrBuilder::create>(passkey.ir_container_, type)); } -TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - ternary_op_type_(src->ternary_op_type_), - out_(ir_cloner->clone(src->out_)), - in1_(ir_cloner->clone(src->in1_)), - in2_(ir_cloner->clone(src->in2_)), - in3_(ir_cloner->clone(src->in3_)) {} - -Expr* TernaryOp::shallowCopy() const { - auto result = - IrBuilder::create(ternary_op_type_, out_, in1_, in2_, in3_); - result->copyPredicatesFrom(this); - return result; -} - -bool TernaryOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getTernaryOpType() != other_op->getTernaryOpType()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(TernaryOp) RNGOp::RNGOp( IrBuilderPasskey passkey, @@ -489,74 +297,31 @@ RNGOp::RNGOp( std::vector parameters, int rng_offset, Val* philox_index) - : Expr(passkey), - rng_op_type_(type), - dtype_(dtype), - parameters_(std::move(parameters)), - rng_offset_(rng_offset), - philox_index_(philox_index) { - if (out->isA()) { - for (auto id : out->as()->getRootDomain()) { - shape_.emplace_back(id->extent()); + : Expr(passkey) { + if (auto tv_out = dynamic_cast(out)) { + for (auto id : tv_out->getRootDomain()) { + TORCH_CHECK(!id->isReduction(), "Output of RNGOp can not have reduction"); + addInput(id->extent()); } } - for (auto v : shape_) { - addInput(v); - } - for (auto v : parameters_) { + for (auto v : parameters) { addInput(v); } addOutput(out); + RNGOp::Attributes attr{type, dtype, rng_offset}; + addAttribute(IrBuilder::create>( + passkey.ir_container_, attr)); + addAttribute(philox_index); } -RNGOp::RNGOp(const RNGOp* 
src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - rng_op_type_(src->rng_op_type_), - dtype_(src->dtype()), - parameters_(ir_cloner->clone(src->parameters_)), - rng_offset_(src->rng_offset_), - philox_index_(ir_cloner->clone(src->philox_index_)) {} - -Expr* RNGOp::shallowCopy() const { - auto result = IrBuilder::create( - rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(RNGOp) -bool RNGOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getRNGOpType() != other_op->getRNGOpType()) { - return false; - } - if (dtype_ != other_op->dtype_) { - return false; - } - if (parameters_.size() != other_op->parameters_.size()) { - return false; - } - for (auto i : c10::irange(parameters_.size())) { - if (!parameters_[i]->sameAs(other_op->parameters_[i])) { - return false; - } - } - if (getRNGOffset() != other_op->getRNGOffset()) { - return false; +size_t RNGOp::getOutputDims() const { + size_t ndims = 0; + if (auto tv_out = dynamic_cast(output(0))) { + ndims = tv_out->getRootDomain().size(); } - if ((philox_index_ == nullptr) != (other_op->philox_index_ == nullptr)) { - return false; - } - if ((philox_index_ != nullptr) && - !philox_index_->sameAs(other_op->philox_index_)) { - return false; - } - return Expr::sameAs(other); + return ndims; } BroadcastOp::BroadcastOp( @@ -564,10 +329,7 @@ BroadcastOp::BroadcastOp( Val* out, Val* in, std::vector is_broadcast_dims) - : Expr(passkey), - out_(out), - in_(in), - is_broadcast_dims_(std::move(is_broadcast_dims)) { + : Expr(passkey) { auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -578,6 +340,8 @@ BroadcastOp::BroadcastOp( addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(is_broadcast_dims))); if (!out->isA() || !in->isA()) { return; @@ -588,13 +352,13 @@ BroadcastOp::BroadcastOp( auto in_dom = TensorDomain::noReductions(in_tv->getMaybeRFactorDomain()); auto& out_dom = out_tv->getRootDomain(); TORCH_INTERNAL_ASSERT( - is_broadcast_dims_.size() == out_dom.size(), + is_broadcast_dims.size() == out_dom.size(), "The dimensions of output tensor and does not match with is_broadcast_dims"); - auto out_size = is_broadcast_dims_.size(); + auto out_size = is_broadcast_dims.size(); auto num_new_broadcasts = 0; for (const auto i : c10::irange(out_size)) { - if (is_broadcast_dims_[i]) { + if (is_broadcast_dims[i]) { num_new_broadcasts++; auto id = out_dom[i]; TORCH_INTERNAL_ASSERT( @@ -618,41 +382,14 @@ BroadcastOp::BroadcastOp( "The dimensions of output tensor and does not match with is_broadcast_dims and input tensor"); } -BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_broadcast_dims_(src->is_broadcast_dims_) {} - -Expr* BroadcastOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, is_broadcast_dims_); - result->copyPredicatesFrom(this); - return result; -} - -bool BroadcastOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getBroadcastDimFlags() != other_op->getBroadcastDimFlags()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(BroadcastOp) 
SqueezeOp::SqueezeOp( IrBuilderPasskey passkey, Val* out, Val* in, std::vector is_squeeze_dims) - : Expr(passkey), - out_(out), - in_(in), - is_squeeze_dims_(std::move(is_squeeze_dims)) { + : Expr(passkey) { auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -663,6 +400,8 @@ SqueezeOp::SqueezeOp( addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(is_squeeze_dims))); if (!out->isA() || !in->isA()) { return; @@ -673,13 +412,13 @@ SqueezeOp::SqueezeOp( auto in_dom = TensorDomain::noReductions(in_tv->getMaybeRFactorDomain()); auto& out_dom = out_tv->getRootDomain(); TORCH_INTERNAL_ASSERT( - is_squeeze_dims_.size() == in_dom.size(), + is_squeeze_dims.size() == in_dom.size(), "The dimensions of input tensor and does not match with is_squeeze_dims"); - auto in_size = is_squeeze_dims_.size(); + auto in_size = is_squeeze_dims.size(); auto num_removed_broadcasts = 0; - for (const auto i : c10::irange(is_squeeze_dims_.size())) { - if (is_squeeze_dims_[i]) { + for (const auto i : c10::irange(is_squeeze_dims.size())) { + if (is_squeeze_dims[i]) { num_removed_broadcasts++; auto id = in_dom[i]; TORCH_INTERNAL_ASSERT( @@ -701,31 +440,7 @@ SqueezeOp::SqueezeOp( "The dimensions of output tensor and does not match with is_squeeze_dims and input tensor"); } -SqueezeOp::SqueezeOp(const SqueezeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_squeeze_dims_(src->is_squeeze_dims_) {} - -Expr* SqueezeOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, is_squeeze_dims_); - result->copyPredicatesFrom(this); - return result; -} - -bool SqueezeOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (getSqueezeDimFlags() != other_op->getSqueezeDimFlags()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(SqueezeOp) ReductionOp::ReductionOp( IrBuilderPasskey passkey, @@ -734,12 +449,7 @@ ReductionOp::ReductionOp( Val* out, Val* in, bool is_allreduce) - : Expr(passkey), - reduction_op_type_(reduction_op_type), - init_(init), - out_(out), - in_(in), - is_allreduce_(is_allreduce) { + : Expr(passkey) { TORCH_CHECK( out->getValType().value() == ValType::TensorView || out->getValType().value() == ValType::TensorIndex); @@ -764,37 +474,14 @@ ReductionOp::ReductionOp( addOutput(out); addInput(in); + addAttribute(init); + addAttribute(IrBuilder::create>( + passkey.ir_container_, reduction_op_type)); + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_allreduce)); } -ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - reduction_op_type_(src->reduction_op_type_), - init_(ir_cloner->clone(src->init_)), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_allreduce_(src->is_allreduce_) {} - -Expr* ReductionOp::shallowCopy() const { - auto result = IrBuilder::create( - reduction_op_type_, init_, out_, in_, is_allreduce_); - result->copyPredicatesFrom(this); - return result; -} - -bool ReductionOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - // Note that init is not part of input vals, so it must be checked separately. 
- return ( - Expr::sameAs(other) && - getReductionOpType() == other_op->getReductionOpType() && - init()->sameAs(other_op->init())); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(ReductionOp) GroupedReductionOp::GroupedReductionOp( IrBuilderPasskey passkey, @@ -803,10 +490,7 @@ GroupedReductionOp::GroupedReductionOp( std::vector outputs, std::vector inputs, bool is_fused) - : Expr(passkey), - reduction_op_types_(std::move(reduction_op_types)), - init_vals_(std::move(init_vals)), - is_allreduce_(is_fused) { + : Expr(passkey) { for (auto out : outputs) { addOutput(out); } @@ -814,23 +498,19 @@ GroupedReductionOp::GroupedReductionOp( for (auto in : inputs) { addInput(in); } -} -GroupedReductionOp::GroupedReductionOp( - const GroupedReductionOp* src, - IrCloner* ir_cloner) - : Expr(src, ir_cloner), - reduction_op_types_(src->reduction_op_types_), - init_vals_(ir_cloner->clone(src->init_vals_)), - is_allreduce_(src->is_allreduce_) {} + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(reduction_op_types))); + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_fused)); -Expr* GroupedReductionOp::shallowCopy() const { - auto result = IrBuilder::create( - reduction_op_types_, init_vals_, outputs(), inputs(), is_allreduce_); - result->copyPredicatesFrom(this); - return result; + for (auto init : init_vals) { + addAttribute(init); + } } +NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedReductionOp) + int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { auto it = std::find(outputs().begin(), outputs().end(), output_val); if (it != outputs().end()) { @@ -841,28 +521,34 @@ int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { false, "Not an output, ", output_val->toString(), ", of ", toString()); } -bool GroupedReductionOp::sameAs(const Statement* other) const { - if (this == other) { - return true; +c10::optional WelfordTriplet::getNameOf( + Val* val) const { + auto it = std::find(begin(), end(), val); + if (it != end()) { + return indexToValName(std::distance(begin(), it)); } - auto grouped_rop = dynamic_cast(other); - if (grouped_rop == nullptr) { - return false; - } + return c10::optional(); +} - if (!Expr::sameAs(other) || - getReductionOpTypes() != grouped_rop->getReductionOpTypes()) { - return false; - } +bool WelfordTriplet::sameAs(const WelfordTriplet& other) const { + return this == &other || + (avg()->sameAs(other.avg()) && var()->sameAs(other.var()) && + N()->sameAs(other.N())); +} - for (const auto i : c10::irange(numExprs())) { - if (!initVal(i)->sameAs(grouped_rop->initVal(i))) { - return false; - } - } +WelfordTriplet WelfordTriplet::clone(IrCloner* ir_cloner) const { + return transform([&](const Val* val) { return ir_cloner->clone(val); }); +} - return true; +std::vector WelfordTriplet::clone( + const std::vector& src, + IrCloner* ir_cloner) { + std::vector cloned; + for (const auto& triplet : src) { + cloned.emplace_back(triplet.clone(ir_cloner)); + } + return cloned; } WelfordOp::WelfordOp( @@ -871,11 +557,7 @@ WelfordOp::WelfordOp( const WelfordTriplet& input, const WelfordTriplet& init, bool is_fused) - : Expr(passkey), - output_(output), - input_(input), - init_(init), - is_allreduce_(is_fused) { + : Expr(passkey) { // Previously, nullptr was accepted and implicitly replaced by // default values. Looks like we always pass some non-null values, // so removed the implicit default behavior for code simplicity. 
@@ -909,74 +591,50 @@ WelfordOp::WelfordOp( // initial value with a count of 1 is un-common enough that I'll push // the responsibility of creating all-zero var tensors to the user TORCH_INTERNAL_ASSERT( - init_.avg()->getValType().value() == ValType::TensorView || - init_.avg()->getValType().value() == ValType::TensorIndex); + init.avg()->getValType().value() == ValType::TensorView || + init.avg()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - init_.var()->getValType().value() == ValType::TensorView || - init_.var()->getValType().value() == ValType::TensorIndex, + init.var()->getValType().value() == ValType::TensorView || + init.var()->getValType().value() == ValType::TensorIndex, "Invalid initial var: ", - init_.var()->toString()); + init.var()->toString()); } // check input TORCH_INTERNAL_ASSERT( - input_.avg()->getValType().value() == ValType::TensorView || - input_.avg()->getValType().value() == ValType::TensorIndex, - input_.avg()->getValType().value()); + input.avg()->getValType().value() == ValType::TensorView || + input.avg()->getValType().value() == ValType::TensorIndex, + input.avg()->getValType().value()); TORCH_INTERNAL_ASSERT( - input_.N()->getValType().value() == ValType::Scalar || - input_.N()->getValType().value() == ValType::TensorView || - input_.N()->getValType().value() == ValType::TensorIndex); - TORCH_INTERNAL_ASSERT(isIntegralType(input_.N()->dtype())); - if (!input_.N()->isOneInt()) { + input.N()->getValType().value() == ValType::Scalar || + input.N()->getValType().value() == ValType::TensorView || + input.N()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(input.N()->dtype())); + if (!input.N()->isOneInt()) { // when input is only one value, only the value is required through avg // input the var part is implicitly 0 and codegen will handle that. 
TORCH_INTERNAL_ASSERT( - input_.var()->getValType().value() == ValType::TensorView || - input_.var()->getValType().value() == ValType::TensorIndex); + input.var()->getValType().value() == ValType::TensorView || + input.var()->getValType().value() == ValType::TensorIndex); } else { TORCH_INTERNAL_ASSERT( - input_.var() == nullptr || input_.var()->isZeroInt(), + input.var() == nullptr || input.var()->isZeroInt(), "Invalid var input, which must be either nullptr or scalar zero when the N input is one."); } - addOutput(output_.avg()); - addOutput(output_.var()); - addOutput(output_.N()); - - addInput(input_.avg()); - addInput(input_.var()); - addInput(input_.N()); -} - -c10::optional WelfordTriplet::getNameOf( - Val* val) const { - auto it = std::find(begin(), end(), val); - if (it != end()) { - return indexToValName(std::distance(begin(), it)); - } - - return c10::optional(); -} + addOutput(output.avg()); + addOutput(output.var()); + addOutput(output.N()); -bool WelfordTriplet::sameAs(const WelfordTriplet& other) const { - return this == &other || - (avg()->sameAs(other.avg()) && var()->sameAs(other.var()) && - N()->sameAs(other.N())); -} + addInput(input.avg()); + addInput(input.var()); + addInput(input.N()); -WelfordTriplet WelfordTriplet::clone(IrCloner* ir_cloner) const { - return transform([&](const Val* val) { return ir_cloner->clone(val); }); -} - -std::vector WelfordTriplet::clone( - const std::vector& src, - IrCloner* ir_cloner) { - std::vector cloned; - for (const auto& triplet : src) { - cloned.emplace_back(triplet.clone(ir_cloner)); - } - return cloned; + addAttribute(init.avg()); + addAttribute(init.var()); + addAttribute(init.N()); + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_fused)); } WelfordOp::WelfordOp( @@ -998,22 +656,10 @@ WelfordOp::WelfordOp( WelfordTriplet(init_avg, init_var, init_N), is_fused) {} -WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - output_(src->output_.clone(ir_cloner)), - input_(src->input_.clone(ir_cloner)), - init_(src->init_.clone(ir_cloner)), - is_allreduce_(src->is_allreduce_) {} - -Expr* WelfordOp::shallowCopy() const { - auto result = - IrBuilder::create(output_, input_, init_, is_allreduce_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(WelfordOp) Val* WelfordOp::getInitValOfOutput(Val* output_val) const { - auto val_name = output().getNameOf(output_val); + auto val_name = outputTriplet().getNameOf(output_val); TORCH_INTERNAL_ASSERT( val_name.has_value(), @@ -1022,21 +668,11 @@ Val* WelfordOp::getInitValOfOutput(Val* output_val) const { " of ", toString()); - return init().get(*val_name); -} - -bool WelfordOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (auto other_wop = dynamic_cast(other)) { - return input_.sameAs(other_wop->input_) && init_.sameAs(other_wop->init_); - } - return false; + return initTriplet().get(*val_name); } std::vector WelfordOp::getInitVals() const { - std::vector init_vals({init_.avg(), init_.var(), init_.N()}); + std::vector init_vals({initAvg(), initVar(), initN()}); return init_vals; } @@ -1046,43 +682,39 @@ GroupedWelfordOp::GroupedWelfordOp( std::vector input_vals, std::vector init_vals, bool is_allreduce) - : Expr(passkey), - output_vals_(std::move(output_vals)), - input_vals_(std::move(input_vals)), - init_vals_(std::move(init_vals)), - is_allreduce_(is_allreduce) { - const auto num_grouped_ops = output_vals_.size(); + : Expr(passkey) { + const auto num_grouped_ops = 
output_vals.size(); TORCH_INTERNAL_ASSERT( - input_vals_.size() == num_grouped_ops, + input_vals.size() == num_grouped_ops, "Invalid number of input arguments. Expected: ", num_grouped_ops, ", Given: ", - input_vals_.size()); + input_vals.size()); TORCH_INTERNAL_ASSERT( - init_vals_.size() == num_grouped_ops, + init_vals.size() == num_grouped_ops, "Invalid number of N arguments. Expected: ", num_grouped_ops, ", Given: ", - init_vals_.size()); + init_vals.size()); for (const auto i : c10::irange(num_grouped_ops)) { // Check output type TORCH_INTERNAL_ASSERT( - output_vals_[i].avg()->getValType().value() == ValType::TensorView || - output_vals_[i].avg()->getValType().value() == ValType::TensorIndex); + output_vals[i].avg()->getValType().value() == ValType::TensorView || + output_vals[i].avg()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - output_vals_[i].var()->getValType().value() == ValType::TensorView || - output_vals_[i].var()->getValType().value() == ValType::TensorIndex); + output_vals[i].var()->getValType().value() == ValType::TensorView || + output_vals[i].var()->getValType().value() == ValType::TensorIndex); TORCH_INTERNAL_ASSERT( - output_vals_[i].N()->getValType().value() == ValType::TensorView || - output_vals_[i].N()->getValType().value() == ValType::TensorIndex); - TORCH_INTERNAL_ASSERT(isIntegralType(output_vals_[i].N()->dtype())); + output_vals[i].N()->getValType().value() == ValType::TensorView || + output_vals[i].N()->getValType().value() == ValType::TensorIndex); + TORCH_INTERNAL_ASSERT(isIntegralType(output_vals[i].N()->dtype())); // check initial value - auto init_avg = init_vals_[i].avg(); - auto init_var = init_vals_[i].var(); - auto init_N = init_vals_[i].N(); + auto init_avg = init_vals[i].avg(); + auto init_var = init_vals[i].var(); + auto init_N = init_vals[i].N(); TORCH_INTERNAL_ASSERT( init_avg != nullptr && init_var != nullptr && init_N != nullptr, "nullptr init vals are not allowed"); @@ -1107,9 +739,9 @@ GroupedWelfordOp::GroupedWelfordOp( init_var->toString()); // check input - auto in_avg = input_vals_[i].avg(); - auto in_var = input_vals_[i].var(); - auto in_N = input_vals_[i].N(); + auto in_avg = input_vals[i].avg(); + auto in_var = input_vals[i].var(); + auto in_N = input_vals[i].N(); TORCH_INTERNAL_ASSERT( in_avg != nullptr && in_var != nullptr && in_N != nullptr, "nullptr input vals are not allowed"); @@ -1141,56 +773,22 @@ GroupedWelfordOp::GroupedWelfordOp( } } + addAttribute( + IrBuilder::create>(passkey.ir_container_, is_allreduce)); for (const auto i : c10::irange(num_grouped_ops)) { - addOutput(output_vals_[i].avg()); - addOutput(output_vals_[i].var()); - addOutput(output_vals_[i].N()); - addInput(input_vals_[i].avg()); - addInput(input_vals_[i].var()); - addInput(input_vals_[i].N()); + addOutput(output_vals[i].avg()); + addOutput(output_vals[i].var()); + addOutput(output_vals[i].N()); + addInput(input_vals[i].avg()); + addInput(input_vals[i].var()); + addInput(input_vals[i].N()); + addAttribute(init_vals[i].avg()); + addAttribute(init_vals[i].var()); + addAttribute(init_vals[i].N()); } } -GroupedWelfordOp::GroupedWelfordOp( - const GroupedWelfordOp* src, - IrCloner* ir_cloner) - : Expr(src, ir_cloner), - output_vals_(WelfordTriplet::clone(src->output_vals_, ir_cloner)), - input_vals_(WelfordTriplet::clone(src->input_vals_, ir_cloner)), - init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)), - is_allreduce_(src->is_allreduce_) {} - -Expr* GroupedWelfordOp::shallowCopy() const { - auto result = IrBuilder::create( 
- output_vals_, input_vals_, init_vals_, is_allreduce_); - result->copyPredicatesFrom(this); - return result; -} - -bool GroupedWelfordOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - - auto grouped_op = dynamic_cast(other); - if (grouped_op == nullptr) { - return false; - } - - if (!Expr::sameAs(other)) { - return false; - } - - for (const auto i : c10::irange(numExprs())) { - if (!initAvg(i)->sameAs(grouped_op->initAvg(i)) || - !initVar(i)->sameAs(grouped_op->initVar(i)) || - !initN(i)->sameAs(grouped_op->initN(i))) { - return false; - } - } - - return true; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedWelfordOp) int GroupedWelfordOp::getExprIndexOfOutput(Val* output_val) const { for (const auto expr_idx : c10::irange(numExprs())) { @@ -1217,7 +815,7 @@ MmaOp::MmaOp( Val* in_a, Val* in_b, Val* init) - : Expr(passkey), out_(out), in_a_(in_a), in_b_(in_b), init_(init) { + : Expr(passkey) { // Check output type TORCH_INTERNAL_ASSERT( out->getValType().value() == ValType::TensorView || @@ -1236,6 +834,9 @@ MmaOp::MmaOp( addOutput(out); addInput(in_a); addInput(in_b); + addAttribute(init); + addAttribute( + IrBuilder::create>(passkey.ir_container_)); } MmaOp::MmaOp( @@ -1246,34 +847,21 @@ MmaOp::MmaOp( Val* init, OptionsInMma options) : MmaOp(passkey, out, in_a, in_b, init) { - options_ = options; + attribute(1)->as>()->value = options; } -MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_a_(ir_cloner->clone(src->in_a_)), - in_b_(ir_cloner->clone(src->in_b_)), - init_(ir_cloner->clone(src->init_)), - options_(src->options_) {} - -Expr* MmaOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_a_, in_b_, init_); - result->options_ = options_; - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(MmaOp) -bool MmaOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (auto other_mma = dynamic_cast(other)) { - return out_->sameAs(other_mma->out_) && in_a_->sameAs(other_mma->in_a_) && - in_b_->sameAs(other_mma->in_b_) && init_->sameAs(other_mma->init_) && - options_ == other_mma->options_; - } - return false; +void MmaOp::configureOptions(MmaOptions options) { + OptionsInMma& opt = attribute(1)->as>()->value; + TORCH_INTERNAL_ASSERT( + options.macro != MmaOptions::MacroType::NoMMA, + "Un-configured mma type from options."); + TORCH_INTERNAL_ASSERT( + options.accumulator_stride > 0, "Un-configured accumulator stride."); + opt.accumulator_stride = options.accumulator_stride; + opt.macro = options.macro; + opt.operand_layout = options.operand_layout; } TransposeOp::TransposeOp( @@ -1281,7 +869,7 @@ TransposeOp::TransposeOp( TensorView* out, TensorView* in, std::vector new2old) - : Expr(passkey), out_(out), in_(in), new2old_(std::move(new2old)) { + : Expr(passkey) { // Sanity check of the input parameters. Maybe not necessary as they // should be checked at function transpose. @@ -1289,44 +877,36 @@ TransposeOp::TransposeOp( TensorDomain::noReductions(in->getMaybeRFactorDomain()).size() == out->getMaybeRFactorDomain().size()); - TORCH_INTERNAL_ASSERT(new2old_.size() == out->getMaybeRFactorDomain().size()); + TORCH_INTERNAL_ASSERT(new2old.size() == out->getMaybeRFactorDomain().size()); // Make sure the entries of new2old are unique and range from 0 to // N-1, where N == new2old.size(). 
- std::set old_positions(new2old_.begin(), new2old_.end()); - TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size()); + std::set old_positions(new2old.begin(), new2old.end()); + TORCH_INTERNAL_ASSERT(old_positions.size() == new2old.size()); // old_positions is sorted, so the first entry must be 0. TORCH_INTERNAL_ASSERT( *(old_positions.begin()) == 0, "Invalid new2old vector detected: ", - new2old_); + new2old); // The last entry must be N-1, since old_positions is sorted, starts // with 0, and its length is N. TORCH_INTERNAL_ASSERT( - *(old_positions.rbegin()) == (int)(new2old_.size() - 1), + *(old_positions.rbegin()) == (int)(new2old.size() - 1), "Invalid new2old vector detected: ", - new2old_); + new2old); addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(new2old))); } -TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - new2old_(src->new2old_) {} - -Expr* TransposeOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, new2old_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(TransposeOp) std::vector TransposeOp::old2new() const { - std::vector old2new(new2old_.size()); - for (auto new_axis : c10::irange(new2old_.size())) { - auto old_axis = new2old_.at(new_axis); + std::vector old2new(new2old().size()); + for (auto new_axis : c10::irange(new2old().size())) { + auto old_axis = new2old().at(new_axis); old2new[old_axis] = new_axis; } return old2new; @@ -1337,13 +917,10 @@ ExpandOp::ExpandOp( TensorView* out, TensorView* in, std::vector _expanded_extents) - : Expr(passkey), - out_(out), - in_(in), - expanded_extents_(std::move(_expanded_extents)) { + : Expr(passkey) { addOutput(out); addInput(in); - for (auto expanded_extent : expanded_extents_) { + for (auto expanded_extent : _expanded_extents) { TORCH_INTERNAL_ASSERT(expanded_extent != nullptr); TORCH_INTERNAL_ASSERT( expanded_extent->dtype() == DataType::Int, @@ -1352,21 +929,7 @@ ExpandOp::ExpandOp( } } -ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) { - expanded_extents_.reserve(src->expanded_extents_.size()); - for (const auto expanded_extent : src->expanded_extents_) { - expanded_extents_.push_back(ir_cloner->clone(expanded_extent)); - } -} - -Expr* ExpandOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, expanded_extents_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(ExpandOp) ShiftOp::ShiftOp( IrBuilderPasskey passkey, @@ -1374,14 +937,10 @@ ShiftOp::ShiftOp( Val* in, std::vector offsets, std::vector pad_width) - : Expr(passkey), - out_(out), - in_(in), - offsets_(std::move(offsets)), - pad_width_(std::move(pad_width)) { - // clang-tidy complains about out_ that it may be null. - TORCH_INTERNAL_ASSERT(out_ != nullptr); - TORCH_INTERNAL_ASSERT(in_ != nullptr); + : Expr(passkey) { + // clang-tidy complains about out that it may be null. 
+ TORCH_INTERNAL_ASSERT(out != nullptr); + TORCH_INTERNAL_ASSERT(in != nullptr); auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -1391,49 +950,28 @@ ShiftOp::ShiftOp( "Cannot shift a non-tensor object."); TORCH_INTERNAL_ASSERT( - offsets_.size() == - TensorDomain::noReductions(in_->as()->getRootDomain()) + offsets.size() == + TensorDomain::noReductions(in->as()->getRootDomain()) .size(), "Invalid offset vector: ", - offsets_); + offsets); TORCH_INTERNAL_ASSERT( - pad_width_.size() == - TensorDomain::noReductions(in_->as()->getRootDomain()) + pad_width.size() == + TensorDomain::noReductions(in->as()->getRootDomain()) .size(), "Invalid padding width vector: ", - pad_width_); + pad_width); addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(offsets))); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(pad_width))); } -ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - offsets_(src->offsets_), - pad_width_(src->pad_width_) {} - -Expr* ShiftOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, offsets_, pad_width_); - result->copyPredicatesFrom(this); - return result; -} - -bool ShiftOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (offsets() != other_op->offsets()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(ShiftOp) GatherOp::GatherOp( IrBuilderPasskey passkey, @@ -1441,14 +979,10 @@ GatherOp::GatherOp( Val* in, std::vector window_shape, std::vector> pad_width) - : Expr(passkey), - out_(out), - in_(in), - window_shape_(std::move(window_shape)), - pad_width_(std::move(pad_width)) { + : Expr(passkey) { // clang-tidy complains about out_ that it may be null. 
- TORCH_INTERNAL_ASSERT(out_ != nullptr); - TORCH_INTERNAL_ASSERT(in_ != nullptr); + TORCH_INTERNAL_ASSERT(out != nullptr); + TORCH_INTERNAL_ASSERT(in != nullptr); auto out_type = out->getValType().value(); auto in_type = in->getValType().value(); @@ -1458,52 +992,29 @@ GatherOp::GatherOp( "Cannot shift a non-tensor object."); const auto ndims = - TensorDomain::noReductions(in_->as()->getRootDomain()).size(); + TensorDomain::noReductions(in->as()->getRootDomain()).size(); TORCH_INTERNAL_ASSERT( - window_shape_.size() == ndims, + window_shape.size() == ndims, "Invalid window_shape vector: ", - window_shape_); + window_shape); TORCH_INTERNAL_ASSERT( - pad_width_.size() == ndims, "Invalid pad_width vector: ", pad_width_); + pad_width.size() == ndims, "Invalid pad_width vector: ", pad_width); - for (const auto& pad : pad_width_) { + for (const auto& pad : pad_width) { TORCH_INTERNAL_ASSERT( pad.size() == 2, "Padding size for each axis must have two Int vals."); } addOutput(out); addInput(in); + addAttribute(IrBuilder::create>>( + passkey.ir_container_, std::move(window_shape))); + addAttribute(IrBuilder::create>>>( + passkey.ir_container_, std::move(pad_width))); } -GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - window_shape_(src->window_shape_), - pad_width_(src->pad_width_) {} - -Expr* GatherOp::shallowCopy() const { - auto result = - IrBuilder::create(out_, in_, window_shape_, pad_width_); - result->copyPredicatesFrom(this); - return result; -} - -bool GatherOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - if (windowShape() != other_op->windowShape() || - padWidth() != other_op->padWidth()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(GatherOp) int GatherOp::gatherAxis(int axis) const { if (axis < 0) { @@ -1520,62 +1031,35 @@ ViewAsScalar::ViewAsScalar( Val* in, IterDomain* vector_id, Val* index) - : Expr(passkey), out_(out), in_(in), vector_id_(vector_id), index_(index) { + : Expr(passkey) { addOutput(out); addInput(in); + addAttribute(vector_id); + addAttribute(index); } -ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - vector_id_(ir_cloner->clone(src->vector_id_)), - index_(ir_cloner->clone(src->index_)) {} - -Expr* ViewAsScalar::shallowCopy() const { - auto result = IrBuilder::create(out_, in_, vector_id_, index_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(ViewAsScalar) -ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) - : Expr(passkey), out_(out), in_(in) { +ViewOp::ViewOp(IrBuilderPasskey passkey, Val* out, Val* in) : Expr(passkey) { addOutput(out); addInput(in); } -ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) {} - -Expr* ViewOp::shallowCopy() const { - auto result = IrBuilder::create(out_, in_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(ViewOp) LoadStoreOp::LoadStoreOp( IrBuilderPasskey passkey, LoadStoreOpType op_type, Val* out, Val* in) - : Expr(passkey), load_store_type_(op_type), out_(out), in_(in) { + : Expr(passkey) { addOutput(out); addInput(in); + 
addAttribute(IrBuilder::create>( + passkey.ir_container_, op_type)); } -LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - load_store_type_(src->load_store_type_), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)) {} - -Expr* LoadStoreOp::shallowCopy() const { - auto result = IrBuilder::create(load_store_type_, out_, in_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(LoadStoreOp) IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) : start_(_start), extent_(_extent) { @@ -1743,6 +1227,8 @@ IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner) padded_to_size_(src->padded_to_size_), is_mma_swizzled_(src->is_mma_swizzled_) {} +NVFUSER_DEFINE_CLONE(IterDomain) + bool IterDomain::sameAs(const Statement* other) const { if (other == this) { return true; @@ -2178,6 +1664,8 @@ TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) contiguity_(src->contiguity()), has_reduction_(src->has_reduction_) {} +NVFUSER_DEFINE_CLONE(TensorDomain) + bool TensorDomain::hasBlockBroadcast() const { return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { return id->isBroadcast() && id->isThreadDim(); @@ -2650,44 +2138,29 @@ Split::Split( bool inner_split, Val* start_offset, Val* stop_offset) - : Expr(passkey), - outer_{outer}, - inner_{inner}, - in_{in}, - factor_{factor}, - inner_split_{inner_split}, - start_offset_{ - start_offset != nullptr ? start_offset - : passkey.ir_container_->zeroVal()}, - stop_offset_{ - stop_offset != nullptr ? stop_offset - : passkey.ir_container_->zeroVal()} { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( - factor_->isAnInt(), + factor->isAnInt(), "Attempted to create a Split node with a non-integer factor."); + if (start_offset == nullptr) { + start_offset = passkey.ir_container_->zeroVal(); + } + if (stop_offset == nullptr) { + stop_offset = passkey.ir_container_->zeroVal(); + } addOutput(outer); addOutput(inner); addInput(in); // TODO add factor as an input, need to check Split::Split during validation // and need to check BestEffortReplay::findFirstMismatchedID addInput(factor); + addAttribute(factor); + addAttribute( + IrBuilder::create>(passkey.ir_container_, inner_split)); + addAttribute(start_offset); + addAttribute(stop_offset); } -Split::Split(const Split* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - outer_(ir_cloner->clone(src->outer_)), - inner_(ir_cloner->clone(src->inner_)), - in_(ir_cloner->clone(src->in_)), - factor_(ir_cloner->clone(src->factor_)), - inner_split_(src->inner_split_), - start_offset_(ir_cloner->clone(src->start_offset_)), - stop_offset_(ir_cloner->clone(src->stop_offset_)) {} - -Expr* Split::shallowCopy() const { - auto result = IrBuilder::create( - outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(Split) Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { TORCH_INTERNAL_ASSERT(in_extent != nullptr); @@ -2703,52 +2176,18 @@ Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { return in_extent; } -bool Split::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - return Expr::sameAs(other) && - factor()->sameAs(other->as()->factor()) && - innerSplit() == other->as()->innerSplit() && - startOffset()->sameAs(other->as()->startOffset()) && - 
stopOffset()->sameAs(other->as()->stopOffset()); -} - Merge::Merge( IrBuilderPasskey passkey, IterDomain* out, IterDomain* outer, IterDomain* inner) - : Expr(passkey), out_{out}, outer_{outer}, inner_{inner} { + : Expr(passkey) { addOutput(out); addInput(outer); addInput(inner); } -Merge::Merge(const Merge* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_(ir_cloner->clone(src->out_)), - outer_(ir_cloner->clone(src->outer_)), - inner_(ir_cloner->clone(src->inner_)) {} - -Expr* Merge::shallowCopy() const { - auto result = IrBuilder::create(out_, outer_, inner_); - result->copyPredicatesFrom(this); - return result; -} - -bool Merge::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - return Expr::sameAs(other); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(Merge) Swizzle2D::Swizzle2D( IrBuilderPasskey passkey, @@ -2758,47 +2197,18 @@ Swizzle2D::Swizzle2D( IterDomain* in_y, Swizzle2DType swizzle_type, SwizzleMode swizzle_mode) - : Expr(passkey), - out_x_{out_x}, - out_y_{out_y}, - in_x_{in_x}, - in_y_{in_y}, - swizzle_type_(swizzle_type), - swizzle_mode_(swizzle_mode) { + : Expr(passkey) { addOutput(out_x); addOutput(out_y); addInput(in_x); addInput(in_y); + addAttribute(IrBuilder::create>( + passkey.ir_container_, swizzle_type)); + addAttribute(IrBuilder::create>( + passkey.ir_container_, swizzle_mode)); } -Expr* Swizzle2D::shallowCopy() const { - auto result = IrBuilder::create( - out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_); - result->copyPredicatesFrom(this); - return result; -} - -bool Swizzle2D::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - if (!(swizzle_type_ == other->as()->swizzle_type_)) { - return false; - } - return Expr::sameAs(other); -} - -Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - out_x_(ir_cloner->clone(src->out_x_)), - out_y_(ir_cloner->clone(src->out_y_)), - in_x_(ir_cloner->clone(src->in_x_)), - in_y_(ir_cloner->clone(src->in_y_)), - swizzle_type_(src->swizzle_type_), - swizzle_mode_(src->swizzle_mode_) {} +NVFUSER_DEFINE_CLONE_AND_CREATE(Swizzle2D) NamedScalar::NamedScalar( IrBuilderPasskey passkey, @@ -2809,6 +2219,8 @@ NamedScalar::NamedScalar( NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner) : Val(src, ir_cloner), name_(src->name_) {} +NVFUSER_DEFINE_CLONE(NamedScalar) + bool NamedScalar::sameAs(const Statement* other) const { if (this == other) { return true; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index c1281e1a27a2..1ac47c139ed0 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -87,43 +88,114 @@ Val* TensorIndex::index(int i) const { return indices_[i]; } -BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) - : Expr(passkey), war_sync_(war_sync) { +Allocate::Allocate( + IrBuilderPasskey passkey, + Val* buffer, + MemoryType memory_type, + std::vector shape, + bool zero_init, + Allocate* alias) + : Expr(passkey) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + if (!shape.empty()) { + TORCH_INTERNAL_ASSERT( + (shape.size() == 1 && shape[0]->isOneInt()) || + buffer->isA()); + } else { + TORCH_INTERNAL_ASSERT(buffer->isA()); + TORCH_INTERNAL_ASSERT( + buffer->as()->getMemoryType() == 
memory_type); + const auto domain = buffer->as()->domain(); + for (auto axis : domain->noReductions()) { + shape.push_back(axis->extent()); + } + } + + Val* size = nullptr; + for (auto s : shape) { + if (size == nullptr) { + size = s; + } else { + size = IrBuilder::mulExpr(size, s); + } + } + + if (size == nullptr) { + size = FusionGuard::getCurFusion()->oneVal(); + } + + if (alias != nullptr) { + TORCH_INTERNAL_ASSERT(alias != this, "Invalid alias"); + TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type, "Invalid alias"); + } + + addInput(size); + addAttribute(buffer); + addAttribute(IrBuilder::create>( + passkey.ir_container_, memory_type)); + addAttribute( + IrBuilder::create>(passkey.ir_container_, zero_init)); + + addAttribute(alias); + + for (auto s : shape) { + addAttribute(s); + } +} + +Allocate::Allocate( + IrBuilderPasskey passkey, + Val* buffer, + MemoryType memory_type, + Val* size, + bool zero_init) + : Allocate( + passkey, + buffer, + memory_type, + size == nullptr ? std::vector{} : std::vector{size}, + zero_init) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); } -Expr* BlockSync::shallowCopy() const { - auto result = IrBuilder::create(war_sync_); - result->copyPredicatesFrom(this); - return result; +NVFUSER_DEFINE_CLONE_AND_CREATE(Allocate) + +BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) : Expr(passkey) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + addAttribute( + IrBuilder::create>(passkey.ir_container_, war_sync)); } +NVFUSER_DEFINE_CLONE_AND_CREATE(BlockSync) + GridSync::GridSync( IrBuilderPasskey passkey, ParallelTypeBitmap sync_dims, Val* sync_buffer) - : Expr(passkey), sync_dims_(sync_dims), sync_buffer_(sync_buffer) {} - -Expr* GridSync::shallowCopy() const { - auto result = IrBuilder::create(sync_dims_, sync_buffer_); - result->copyPredicatesFrom(this); - return result; + : Expr(passkey) { + addAttribute(IrBuilder::create>( + passkey.ir_container_, sync_dims)); + addAttribute(sync_buffer); } +NVFUSER_DEFINE_CLONE_AND_CREATE(GridSync) + CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) - : Expr(passkey), keep_stages_(keep_stages) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + addAttribute(IrBuilder::create>( + passkey.ir_container_, keep_stages)); } -Expr* CpAsyncWait::shallowCopy() const { - auto result = IrBuilder::create(keep_stages_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(CpAsyncWait) CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) : Expr(passkey) { TORCH_INTERNAL_ASSERT( @@ -131,11 +203,7 @@ CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) : Expr(passkey) { "IR type only valid for Kernel container."); } -Expr* CpAsyncCommit::shallowCopy() const { - auto result = IrBuilder::create(); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(CpAsyncCommit) InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { TORCH_INTERNAL_ASSERT( @@ -143,11 +211,7 @@ InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { "IR type only valid for Kernel container."); } -Expr* InitMagicZero::shallowCopy() const { - auto result = IrBuilder::create(); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(InitMagicZero) UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) : 
Expr(passkey) { TORCH_INTERNAL_ASSERT( @@ -155,11 +219,7 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) : Expr(passkey) { "IR type only valid for Kernel container."); } -Expr* UpdateMagicZero::shallowCopy() const { - auto result = IrBuilder::create(); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(UpdateMagicZero) void Scope::insert(std::vector::const_iterator pos, Expr* expr) { exprs_.insert(pos, expr); @@ -232,33 +292,37 @@ ForLoop::ForLoop( Val* vectorize_shift, bool unroll_required, DoubleBufferLoopStage double_buffer_loop_stage) - : Expr(passkey), - iter_domain_{iter_domain}, - index_(index), - start_(start), - stop_(stop), - step_(step), - vectorize_(vectorize), - vectorize_shift_(vectorize_shift), - unroll_required_(unroll_required), - body_(this), - double_buffer_loop_stage_(double_buffer_loop_stage) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int); addInput(index); addInput(iter_domain); - if (start_ == nullptr && iter_domain->isThread()) { - start_ = NamedScalar::getParallelIndex(iter_domain->getParallelType()); + if (start == nullptr && iter_domain->isThread()) { + start = NamedScalar::getParallelIndex(iter_domain->getParallelType()); } - if (step_ == nullptr) { + if (step == nullptr) { if (iter_domain->isThread()) { - step_ = NamedScalar::getParallelDim(iter_domain->getParallelType()); + step = NamedScalar::getParallelDim(iter_domain->getParallelType()); } else { - step_ = FusionGuard::getCurFusion()->oneVal(); + step = FusionGuard::getCurFusion()->oneVal(); } } + addAttribute(start); + addAttribute(stop); + addAttribute(step); + addAttribute( + IrBuilder::create>(passkey.ir_container_, vectorize)); + addAttribute(vectorize_shift); + addAttribute(IrBuilder::create>( + passkey.ir_container_, unroll_required)); + addAttribute(IrBuilder::create>( + passkey.ir_container_, double_buffer_loop_stage)); + // Storing IR nodes as Attribute is not safe with IrCloner, but fortunately + // kernel IR does not need this feature. + addAttribute( + IrBuilder::create>(passkey.ir_container_, this)); } ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain) @@ -296,21 +360,7 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other) "IR type only valid for Kernel container."); } -Expr* ForLoop::shallowCopy() const { - auto result = IrBuilder::create( - iter_domain_, - index_, - start_, - stop_, - step_, - vectorize_, - vectorize_shift_, - unroll_required_, - double_buffer_loop_stage_); - result->body_ = body_; - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(ForLoop) bool ForLoop::isUnrollable() const { // Start and stop must be constant, must not be a broadcast @@ -325,7 +375,7 @@ bool ForLoop::isUnrolled() const { if (isUnrollRequired() && !isUnrollable()) { TORCH_WARN( "Unroll required but not possible. Register allocation disabled. 
Loop index: ", - index_->toString()); + index()->toString()); return false; } @@ -356,28 +406,28 @@ bool ForLoop::isUnrolled() const { } Val* ForLoop::start() const { - if (start_ != nullptr) { - return start_; + if (attributeVal(0) != nullptr) { + return attributeVal(0); } else { // clang-tidy complains without this - TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr); - return iter_domain_->start(); + TORCH_INTERNAL_ASSERT(iter_domain() != nullptr); + return iter_domain()->start(); } } Val* ForLoop::stop() const { - if (stop_ != nullptr) { - return stop_; + if (attributeVal(1) != nullptr) { + return attributeVal(1); } else { // clang-tidy complains without this - TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr); - return iter_domain_->extent(); + TORCH_INTERNAL_ASSERT(iter_domain() != nullptr); + return iter_domain()->extent(); } } Val* ForLoop::step() const { - TORCH_INTERNAL_ASSERT(step_ != nullptr); - return step_; + TORCH_INTERNAL_ASSERT(attributeVal(2) != nullptr); + return attributeVal(2); } bool ForLoop::isTrivial() const { @@ -426,93 +476,18 @@ bool ForLoop::isTrivial() const { } IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond) - : Expr(passkey), then_body_(this), else_body_(this) { + : Expr(passkey) { setPredicate(cond); addInput(cond); + // Storing IR nodes as Attribute is not safe with IrCloner, but fortunately + // kernel IR does not need this feature. + addAttribute( + IrBuilder::create>(passkey.ir_container_, this)); + addAttribute( + IrBuilder::create>(passkey.ir_container_, this)); } -Expr* IfThenElse::shallowCopy() const { - auto result = IrBuilder::create(predicate()); - result->then_body_ = then_body_; - result->else_body_ = else_body_; - result->setWritePredicate(writePredicate()); - return result; -} - -Allocate::Allocate( - IrBuilderPasskey passkey, - Val* buffer, - MemoryType memory_type, - std::vector shape, - bool zero_init, - const Allocate* alias) - : Expr(passkey), - buffer_(buffer), - memory_type_(memory_type), - shape_(std::move(shape)), - zero_init_(zero_init), - alias_(alias) { - TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); - if (!shape_.empty()) { - TORCH_INTERNAL_ASSERT( - (shape_.size() == 1 && shape_[0]->isOneInt()) || - buffer_->isA()); - } else { - TORCH_INTERNAL_ASSERT(buffer_->isA()); - TORCH_INTERNAL_ASSERT( - buffer_->as()->getMemoryType() == memory_type_); - const auto domain = buffer_->as()->domain(); - for (auto axis : domain->noReductions()) { - shape_.push_back(axis->extent()); - } - } - - for (auto s : shape_) { - if (size_ == nullptr) { - size_ = s; - } else { - size_ = IrBuilder::mulExpr(size_, s); - } - } - - if (size_ == nullptr) { - size_ = FusionGuard::getCurFusion()->oneVal(); - } - - if (alias_ != nullptr) { - TORCH_INTERNAL_ASSERT(alias_ != this, "Invalid alias"); - TORCH_INTERNAL_ASSERT( - alias_->memoryType() == memory_type_, "Invalid alias"); - } - - addInput(size_); -} - -Allocate::Allocate( - IrBuilderPasskey passkey, - Val* buffer, - MemoryType memory_type, - Val* size, - bool zero_init) - : Allocate( - passkey, - buffer, - memory_type, - size == nullptr ? 
std::vector{} : std::vector{size}, - zero_init) { - TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); -} - -Expr* Allocate::shallowCopy() const { - auto result = - IrBuilder::create(buffer_, memory_type_, shape_, zero_init_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(IfThenElse) GridReduction::GridReduction( IrBuilderPasskey passkey, @@ -525,31 +500,23 @@ GridReduction::GridReduction( Val* entrance_index, Val* entrances, bool is_allreduce) - : ReductionOp(passkey, reduction_op_type, init, out, in, is_allreduce), - reduction_buffer_(reduction_buffer), - sync_buffer_(sync_buffer), - entrance_index_(entrance_index), - entrances_(entrances) { + : ReductionOp(passkey, reduction_op_type, init, out, in, is_allreduce) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT( + attributes().size() == num_reduction_op_attr, + "The num_reduction_op_attr does not match the number of attributes ReductionOp has." + "If you changed ReductionOp, please change num_reduction_op_attr accordingly."); + addAttribute(reduction_buffer); + addAttribute(sync_buffer); + addAttribute(entrance_index); + addAttribute(entrances); + addAttribute( + IrBuilder::create>(passkey.ir_container_)); } -Expr* GridReduction::shallowCopy() const { - auto result = IrBuilder::create( - getReductionOpType(), - init(), - out(), - in(), - reduction_buffer_, - sync_buffer_, - entrance_index_, - entrances_, - isAllreduce()); - result->copyPredicatesFrom(this); - result->thread_predicate_ = thread_predicate_; - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(GridReduction) GroupedGridReduction::GroupedGridReduction( IrBuilderPasskey passkey, @@ -569,54 +536,42 @@ GroupedGridReduction::GroupedGridReduction( std::move(init_vals), std::move(outputs), std::move(inputs), - is_allreduce), - reduction_buffers_(std::move(reduction_buffers)), - sync_buffer_(sync_buffer), - entrance_index_(entrance_index), - entrances_(entrances), - buffer_stride_(buffer_stride) { + is_allreduce) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT( + attributes().size() == numGroupedReductionOpAttr(), + "The numGroupedReductionOpAttr() does not match the number of attributes GroupedReductionOp has." 
+ "If you changed GroupedReductionOp, please change numGroupedReductionOpAttr() accordingly."); + addAttribute(sync_buffer); + addAttribute(entrance_index); + addAttribute(entrances); + addAttribute(buffer_stride); + addAttribute( + IrBuilder::create>(passkey.ir_container_)); + for (auto buffer : reduction_buffers) { + addAttribute(buffer); + } } -Expr* GroupedGridReduction::shallowCopy() const { - auto result = IrBuilder::create( - getReductionOpTypes(), - initVals(), - outputs(), - inputs(), - reduction_buffers_, - sync_buffer_, - entrance_index_, - entrances_, - buffer_stride_, - isAllreduce()); - result->copyPredicatesFrom(this); - result->thread_predicate_ = thread_predicate_; - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedGridReduction) GridBroadcast::GridBroadcast( IrBuilderPasskey passkey, BroadcastOp* broadcast_op, Allocate* broadcast_buffer, Allocate* sync_buffer) - : Expr(passkey), - broadcast_op_(broadcast_op), - broadcast_buffer_(broadcast_buffer), - sync_buffer_(sync_buffer) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + addAttribute(broadcast_op); + addAttribute(broadcast_buffer); + addAttribute(sync_buffer); } -Expr* GridBroadcast::shallowCopy() const { - auto result = IrBuilder::create( - broadcast_op_, broadcast_buffer_, sync_buffer_); - result->copyPredicatesFrom(this); - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(GridBroadcast) GridWelford::GridWelford( IrBuilderPasskey passkey, @@ -627,32 +582,22 @@ GridWelford::GridWelford( Allocate* sync_buffer, Val* entrance_index, Val* entrances) - : Expr(passkey), - welford_op_(welford_op), - var_buffer_(var_buffer), - avg_buffer_(avg_buffer), - n_buffer_(n_buffer), - sync_buffer_(sync_buffer), - entrance_index_(entrance_index), - entrances_(entrances) { + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + addAttribute(welford_op); + addAttribute(var_buffer); + addAttribute(avg_buffer); + addAttribute(n_buffer); + addAttribute(sync_buffer); + addAttribute(entrance_index); + addAttribute(entrances); + addAttribute( + IrBuilder::create>(passkey.ir_container_)); } -Expr* GridWelford::shallowCopy() const { - auto result = IrBuilder::create( - welford_op_, - var_buffer_, - avg_buffer_, - n_buffer_, - sync_buffer_, - entrance_index_, - entrances_); - result->copyPredicatesFrom(this); - result->thread_predicate_ = thread_predicate_; - return result; -} +NVFUSER_DEFINE_CLONE_AND_CREATE(GridWelford) GroupedGridWelford::GroupedGridWelford( IrBuilderPasskey passkey, @@ -670,133 +615,79 @@ GroupedGridWelford::GroupedGridWelford( std::move(output_vals), std::move(input_vals), std::move(init_vals), - is_allreduce), - reduction_buffers_(std::move(reduction_buffers)), - sync_buffer_(sync_buffer), - entrance_index_(entrance_index), - entrances_(entrances), - buffer_stride_(buffer_stride) { + is_allreduce) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); -} - -Expr* GroupedGridWelford::shallowCopy() const { - auto result = IrBuilder::create( - outputVals(), - inputVals(), - initVals(), - reduction_buffers_, - sync_buffer_, - entrance_index_, - entrances_, - buffer_stride_, - isAllreduce()); - result->copyPredicatesFrom(this); - result->thread_predicate_ = thread_predicate_; - return result; -} - -AllocateFusedReduction::AllocateFusedReduction( - IrBuilderPasskey passkey, - GridReduction* grid_reduction) - : Expr(passkey), 
grid_expr_(grid_reduction) { TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); -} - -AllocateFusedReduction::AllocateFusedReduction( - IrBuilderPasskey passkey, - GridWelford* grid_welford) - : Expr(passkey), grid_expr_(grid_welford) { + attributes().size() == numGroupedWelfordOpAttr(), + "The numGroupedWelfordOpAttr() does not match the number of attributes GroupedWelfordOp has." + "If you changed GroupedReductionOp, please change numGroupedWelfordOpAttr() accordingly."); + addAttribute(sync_buffer); + addAttribute(entrance_index); + addAttribute(entrances); + addAttribute(buffer_stride); + addAttribute( + IrBuilder::create>(passkey.ir_container_)); TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); -} - -AllocateFusedReduction::AllocateFusedReduction( - IrBuilderPasskey passkey, - GroupedGridReduction* grouped_grid_reduction) - : Expr(passkey), grid_expr_(grouped_grid_reduction) { + reduction_buffers[0].size() == reduction_buffers[1].size()); TORCH_INTERNAL_ASSERT( - passkey.ir_container_->isA(), - "IR type only valid for Kernel container."); + reduction_buffers[0].size() == reduction_buffers[2].size()); + for (auto i : c10::irange(reduction_buffers[0].size())) { + addAttribute(reduction_buffers[0].at(i)); + addAttribute(reduction_buffers[1].at(i)); + addAttribute(reduction_buffers[2].at(i)); + } } +NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedGridWelford) + AllocateFusedReduction::AllocateFusedReduction( IrBuilderPasskey passkey, - GroupedGridWelford* grouped_grid_welford) - : Expr(passkey), grid_expr_(grouped_grid_welford) { + Expr* grid_expr) + : Expr(passkey) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); + addAttribute(grid_expr); } -Expr* AllocateFusedReduction::shallowCopy() const { - if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } else if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } else if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } else if (grid_expr_->isA()) { - auto result = IrBuilder::create( - grid_expr_->as()); - result->setPredicate(predicate()); - result->setWritePredicate(writePredicate()); - return result; - } - TORCH_INTERNAL_ASSERT( - false, "Unknown reduction type in AllocateFusedReduction::shallowCopy"); -} +NVFUSER_DEFINE_CLONE_AND_CREATE(AllocateFusedReduction) TensorIndex* AllocateFusedReduction::out() const { - TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr); - if (grid_expr_->isA() || - grid_expr_->isA()) { - return grid_expr_->outputs().at(0)->as(); - } else if (auto grid_welford = dynamic_cast(grid_expr_)) { + TORCH_INTERNAL_ASSERT(gridExpr() != nullptr); + if (gridExpr()->isA() || + gridExpr()->isA()) { + return gridExpr()->outputs().at(0)->as(); + } else if (auto grid_welford = dynamic_cast(gridExpr())) { return grid_welford->welford_op()->out()->as(); } else if ( auto grouped_grid_welford = - dynamic_cast(grid_expr_)) { + dynamic_cast(gridExpr())) { return grouped_grid_welford->out(0)->as(); } else { TORCH_INTERNAL_ASSERT( - false, "Invalid grid expression: ", grid_expr_->toString()); + false, "Invalid 
grid expression: ", gridExpr()->toString()); } } const ParallelTypeBitmap& AllocateFusedReduction::threadPredicate() const { - TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr); - if (auto grid_reduction = dynamic_cast(grid_expr_)) { + TORCH_INTERNAL_ASSERT(gridExpr() != nullptr); + if (auto grid_reduction = dynamic_cast(gridExpr())) { return grid_reduction->threadPredicate(); - } else if (auto grid_welford = dynamic_cast(grid_expr_)) { + } else if (auto grid_welford = dynamic_cast(gridExpr())) { return grid_welford->threadPredicate(); } else if ( auto grouped_grid_reduction = - dynamic_cast(grid_expr_)) { + dynamic_cast(gridExpr())) { return grouped_grid_reduction->threadPredicate(); } else if ( auto grouped_grid_welford = - dynamic_cast(grid_expr_)) { + dynamic_cast(gridExpr())) { return grouped_grid_welford->threadPredicate(); } else { TORCH_INTERNAL_ASSERT( - false, "Invalid grid expression: ", grid_expr_->toString()); + false, "Invalid grid expression: ", gridExpr()->toString()); } } diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 24f2869fcf05..763b32443c6d 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -156,6 +156,8 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { //! describes the output of an operation. class TORCH_CUDA_CU_API Allocate final : public Expr { public: + using Expr::Expr; + //! Allocation of a multi-dimensional buffer //! //! param shape Size of each dimension @@ -165,7 +167,7 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { MemoryType memory_type, std::vector shape = {}, bool zero_init = false, - const Allocate* alias = nullptr); + Allocate* alias = nullptr); //! Allocation of a non-dimensional buffer //! @@ -181,44 +183,40 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { return "Allocate"; } - Expr* shallowCopy() const override; + NVFUSER_DECLARE_CLONE_AND_CREATE Val* buffer() const { - return buffer_; + return attributeVal(0); } MemoryType memoryType() const { - return memory_type_; + return attribute(1)->as>()->value; } + //! Total size Val* size() const { - return size_; + return input(0); } - const std::vector& shape() const { - return shape_; + //! Size of each dimension + std::vector shape() const { + std::vector result; + result.reserve(attributes().size() - 4); + for (auto i = attributes().begin() + 4; i != attributes().end(); ++i) { + result.emplace_back((*i)->as()); + } + return result; } bool zeroInit() const { - return zero_init_; + return attribute(2)->as>()->value; } - const Allocate* alias() const { - return alias_; - } - - private: - Val* buffer_ = nullptr; - MemoryType memory_type_ = MemoryType::Local; - //! Size of each dimension - std::vector shape_; - bool zero_init_ = false; - //! Total size - Val* size_ = nullptr; - // This alias tracks the next Allocate node in a linked chain of aliases // If the alias is nullptr, then the Allocate node uses memory in the kernel - const Allocate* alias_ = nullptr; + const Allocate* alias() const { + return dynamic_cast(attribute(3)); + } }; // Sync represents __syncthreads barrier for block level coordination. 
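Illustrative sketch (not part of the patch): a self-contained toy model of the storage pattern the Allocate hunks above switch to, where per-expression state lives in a generic attribute list on the base class and typed accessors read it back by position. All names below (ToyVal, ToyAttr, ToyExpr, ToyAlloc) are invented for this sketch and are not nvfuser APIs; the real classes appear to wrap plain data in Attribute IR nodes created through IrBuilder rather than shared_ptrs.

// Toy model of "state as a positional attribute list + typed accessors".
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct ToyVal {
  virtual ~ToyVal() = default;
};

// Analogous in spirit to an Attribute wrapper: a Val-like box for plain data.
template <typename T>
struct ToyAttr : ToyVal {
  T value;
  explicit ToyAttr(T v) : value(std::move(v)) {}
};

struct ToyExpr {
  std::vector<std::shared_ptr<ToyVal>> attributes;

  // Analogous to reading attribute(i) back with a typed cast.
  template <typename T>
  T& attr(size_t i) {
    return dynamic_cast<ToyAttr<T>&>(*attributes.at(i)).value;
  }
};

// Analogous to kir::Allocate after the patch: no named member fields, only
// attributes appended in a fixed order by the constructor.
struct ToyAlloc : ToyExpr {
  ToyAlloc(int memory_type, bool zero_init) {
    attributes.push_back(std::make_shared<ToyAttr<int>>(memory_type));
    attributes.push_back(std::make_shared<ToyAttr<bool>>(zero_init));
  }
  int memoryType() { return attr<int>(0); }
  bool zeroInit() { return attr<bool>(1); }
};

int main() {
  ToyAlloc alloc(/*memory_type=*/2, /*zero_init=*/true);
  assert(alloc.memoryType() == 2);
  assert(alloc.zeroInit());
  return 0;
}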
@@ -227,43 +225,66 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { // class TORCH_CUDA_CU_API BlockSync final : public Expr { public: + using Expr::Expr; + explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); virtual const char* getOpString() const override { return "BlockSync"; } - Expr* shallowCopy() const override; + NVFUSER_DECLARE_CLONE_AND_CREATE + // TODO: war_sync_ is only used for testing/validation purposes. bool isWarHazardSync() const { - return war_sync_; + return attribute(0)->as>()->value; } +}; - private: - // TODO: war_sync_ is only used for testing/validation purposes. - bool war_sync_ = false; +// Synchronize all blocks in device, implies cooperative group launch is +// required. +class TORCH_CUDA_CU_API GridSync final : public Expr { + public: + using Expr::Expr; + + explicit GridSync( + IrBuilderPasskey passkey, + ParallelTypeBitmap sync_dims, + Val* sync_buffer); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + virtual const char* getOpString() const override { + return "GridSync"; + } + + ParallelTypeBitmap syncDims() const { + return attribute(0)->as>()->value; + } + + Val* syncBuffer() const { + return attributeVal(1); + } }; // CpAsyncWait represents wait intrinsics for cp.async class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { public: + using Expr::Expr; + explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "CpAsyncWait"; } - Expr* shallowCopy() const override; - //! Returns the remaining number of stages that are not synchronized //! after this op. unsigned int keepStages() const { - return keep_stages_; + return attribute(0)->as>()->value; } - - private: - //! Number of stage to leave un-sync'ed by this op. - unsigned int keep_stages_ = 0; }; // CpAsyncCommit represents commit intrinsics for cp.async @@ -271,67 +292,45 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { // to the async load hardware. Example usage see [Cicular buffer]. class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { public: - explicit CpAsyncCommit(IrBuilderPasskey passkey); + using Expr::Expr; - virtual const char* getOpString() const override { - return "CpAsyncCommit"; - } - - Expr* shallowCopy() const override; -}; + explicit CpAsyncCommit(IrBuilderPasskey passkey); -// Synchronize all blocks in device, implies cooperative group launch is -// required. 
-class TORCH_CUDA_CU_API GridSync final : public Expr { - public: - explicit GridSync( - IrBuilderPasskey passkey, - ParallelTypeBitmap sync_dims, - Val* sync_buffer); + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { - return "GridSync"; - } - - Expr* shallowCopy() const override; - - ParallelTypeBitmap syncDims() const { - return sync_dims_; - } - - Val* syncBuffer() const { - return sync_buffer_; + return "CpAsyncCommit"; } - - private: - ParallelTypeBitmap sync_dims_; - Val* sync_buffer_ = nullptr; }; // Simply prints "DEFINE_MAGIC_ZERO" in the code in accordance with magic_zero // in helpers.cu class TORCH_CUDA_CU_API InitMagicZero final : public Expr { public: + using Expr::Expr; + explicit InitMagicZero(IrBuilderPasskey passkey); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "InitMagicZero"; } - - Expr* shallowCopy() const override; }; // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero // in helpers.cu class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr { public: + using Expr::Expr; + explicit UpdateMagicZero(IrBuilderPasskey passkey); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "UpdateMagicZero"; } - - Expr* shallowCopy() const override; }; // TODO(kir): promote to IR node @@ -386,6 +385,10 @@ class TORCH_CUDA_CU_API Scope { return owner_; } + bool operator==(const Scope&) const { + TORCH_INTERNAL_ASSERT(false, "Should not reach here"); + } + private: // Insert expr before pos void insert(std::vector::const_iterator pos, Expr* expr); @@ -412,6 +415,8 @@ class TORCH_CUDA_CU_API Scope { //! be smaller than the extent of iter_domain_. class TORCH_CUDA_CU_API ForLoop final : public Expr { public: + using Expr::Expr; + //! By default, start and stop are the same as those of iter_domain. //! Step is one by default. //! @@ -432,14 +437,14 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { ForLoop(IrBuilderPasskey passkey, const ForLoop* other); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "ForLoop"; } - Expr* shallowCopy() const override; - Val* index() const { - return index_; + return input(0); } Val* start() const; @@ -448,38 +453,42 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { Val* step() const; + // [pre | vectorize | post] <= inner-most, merged root domain + // shift_ is applied to vectorize and post sections. Val* vectorize_shift() const { - return vectorize_shift_; + return attributeVal(4); } IterDomain* iter_domain() const { - return iter_domain_; + return input(1)->as(); } // TODO: Return pointer instead of reference to be more consistent Scope& body() { - return body_; + return attribute(7)->as>()->value; } const Scope& body() const { - return body_; + return attribute(7)->as>()->value; } + // vectorize is true when the for-loop contains a vectorize set + // the flag is used to omit the for-loop from the kernel bool vectorize() const { - return vectorize_; + return attribute(3)->as>()->value; } //! True if unrolled (i.e., "#pragma unroll" is attached) bool isUnrolled() const; - //! True if unrolling is required + //! True if unroll is required for avoiding stack allocation bool isUnrollRequired() const { - return unroll_required_; + return attribute(5)->as>()->value; } //! Set unrolling required void requireUnroll() { - unroll_required_ = true; + attribute(5)->as>()->value = true; } //! 
True if no actual for-loop is materialized @@ -488,37 +497,12 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { //! Returns the stage of a double buffered iterdomain //! that this for loop materializes. auto doubleBufferLoopStage() const { - return double_buffer_loop_stage_; + return attribute(6)->as>()->value; } private: //! Returns if a loop could be unrolled. bool isUnrollable() const; - - private: - IterDomain* const iter_domain_ = nullptr; - - Val* index_ = nullptr; - Val* start_ = nullptr; - Val* stop_ = nullptr; - Val* step_ = nullptr; - - // vectorize is true when the for-loop contains a vectorize set - // the flag is used to omit the for-loop from the kernel - bool vectorize_ = false; - // [pre | vectorize | post] <= inner-most, merged root domain - // shift_ is applied to vectorize and post sections. - Val* vectorize_shift_ = nullptr; - - //! True if unroll is required for avoiding stack allocation - bool unroll_required_ = false; - - Scope body_; - - //! Tracks if this for loop is implementing a stage of - //! a double buffered iterdomain. - DoubleBufferLoopStage double_buffer_loop_stage_ = - DoubleBufferLoopStage::NotApplicable; }; //! IfThenElse provides scoping for an boolean operator. Exprs placed in its @@ -530,36 +514,34 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { //! class TORCH_CUDA_CU_API IfThenElse final : public Expr { public: + using Expr::Expr; + explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "IfThenElse"; } - Expr* shallowCopy() const override; - Scope& thenBody() { - return then_body_; + return attribute(0)->as>()->value; } const Scope& thenBody() const { - return then_body_; + return attribute(0)->as>()->value; } Scope& elseBody() { - return else_body_; + return attribute(1)->as>()->value; } const Scope& elseBody() const { - return else_body_; + return attribute(1)->as>()->value; } bool hasElse() const { - return !else_body_.empty(); + return !elseBody().empty(); } - - private: - Scope then_body_; - Scope else_body_; }; //! Grid reduction operation @@ -570,7 +552,11 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr { //! This node provides FusionExecutor the information it needs to allocate the //! reduction and sync buffers. class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { + static constexpr int num_reduction_op_attr = 3; + public: + using ReductionOp::ReductionOp; + GridReduction( IrBuilderPasskey passkey, BinaryOpType reduction_op_type, @@ -583,54 +569,57 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { Val* entrances, bool is_allreduce = false); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "GridReduction"; } - Expr* shallowCopy() const override; - Allocate* reduction_buffer() const { - return reduction_buffer_; + return attribute(num_reduction_op_attr)->as(); } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(num_reduction_op_attr + 1)->as(); } // Which instance of entering this grid reduction is this iteration? Val* entrance_index() const { - return entrance_index_; + return attributeVal(num_reduction_op_attr + 2); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attributeVal(num_reduction_op_attr + 3); } + // gridReduce has template flags for thread predicates. 
In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(num_reduction_op_attr + 4) + ->as>() + ->value; + } + + ParallelTypeBitmap& threadPredicate() { + return attribute(num_reduction_op_attr + 4) + ->as>() + ->value; } GridReduction* withThreadPredicate( const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - Allocate* reduction_buffer_ = nullptr; - Allocate* sync_buffer_ = nullptr; - // gridReduce has template flags for thread predicates. In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. - ParallelTypeBitmap thread_predicate_; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; }; class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { public: + using GroupedReductionOp::GroupedReductionOp; + GroupedGridReduction( IrBuilderPasskey passkey, std::vector reduction_op_type, @@ -644,60 +633,72 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { Val* buffer_stride, bool is_allreduce = false); + NVFUSER_DECLARE_CLONE_AND_CREATE + + // number of attributes in the parent class + int numGroupedReductionOpAttr() const { + return 2 + outputs().size(); + } + virtual const char* getOpString() const override { return "GroupedGridReduction"; } - Expr* shallowCopy() const override; - - const std::vector& reduction_buffers() const { - return reduction_buffers_; + std::vector reduction_buffers() const { + auto offset = numGroupedReductionOpAttr() + 5; + auto size = outputs().size(); + std::vector result; + result.reserve(size); + for (auto i : c10::irange(offset, offset + size)) { + result.emplace_back(attribute(i)->as()); + } + return result; } Allocate* reduction_buffer(size_t i) const { - return reduction_buffers_.at(i); + return reduction_buffers().at(i); } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(numGroupedReductionOpAttr())->as(); } // Which instance of entering this grid reduction is this iteration? Val* entrance_index() const { - return entrance_index_; + return attributeVal(numGroupedReductionOpAttr() + 1); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attributeVal(numGroupedReductionOpAttr() + 2); } + // Stride of reduction buffers Val* buffer_stride() const { - return buffer_stride_; + return attributeVal(numGroupedReductionOpAttr() + 3); } + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(numGroupedReductionOpAttr() + 4) + ->as>() + ->value; + } + + ParallelTypeBitmap& threadPredicate() { + return attribute(numGroupedReductionOpAttr() + 4) + ->as>() + ->value; } GroupedGridReduction* withThreadPredicate( const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - std::vector reduction_buffers_; - Allocate* sync_buffer_ = nullptr; - // gridReduce has template flags for thread predicates. 
In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. - ParallelTypeBitmap thread_predicate_; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; - // Stride of reduction buffers - Val* buffer_stride_ = nullptr; }; //! Grid broadcast operation @@ -709,34 +710,31 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { //! broadcast and sync buffers. class TORCH_CUDA_CU_API GridBroadcast final : public Expr { public: + using Expr::Expr; + GridBroadcast( IrBuilderPasskey passkey, BroadcastOp* broadcast_op, Allocate* broadcast_buffer, Allocate* sync_buffer); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "GridBroadcast"; } - Expr* shallowCopy() const override; - BroadcastOp* broadcast_op() const { - return broadcast_op_; + return attribute(0)->as(); } Allocate* broadcast_buffer() const { - return broadcast_buffer_; + return attribute(1)->as(); } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(2)->as(); } - - private: - BroadcastOp* broadcast_op_ = nullptr; - Allocate* broadcast_buffer_ = nullptr; - Allocate* sync_buffer_ = nullptr; }; //! Grid welford operation @@ -750,6 +748,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { //! TODO: Make this a subclass of WelfordOp class TORCH_CUDA_CU_API GridWelford final : public Expr { public: + using Expr::Expr; + GridWelford( IrBuilderPasskey passkey, WelfordOp* welford_op, @@ -760,68 +760,63 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { Val* entrance_index, Val* entrances); + NVFUSER_DECLARE_CLONE_AND_CREATE + virtual const char* getOpString() const override { return "GridWelford"; } - Expr* shallowCopy() const override; - WelfordOp* welford_op() const { - return welford_op_; + return attribute(0)->as(); } Allocate* var_buffer() const { - return var_buffer_; + return attribute(1)->as(); } Allocate* avg_buffer() const { - return avg_buffer_; + return attribute(2)->as(); } Allocate* N_buffer() const { - return n_buffer_; + return attribute(3)->as(); } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(4)->as(); } // Which instance of entering this grid reduction is this iteration? Val* entrance_index() const { - return entrance_index_; + return attributeVal(5); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attributeVal(6); } + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(7)->as>()->value; + } + ParallelTypeBitmap& threadPredicate() { + return attribute(7)->as>()->value; } GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - WelfordOp* welford_op_ = nullptr; - Allocate* var_buffer_ = nullptr; - Allocate* avg_buffer_ = nullptr; - Allocate* n_buffer_ = nullptr; - Allocate* sync_buffer_ = nullptr; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; - // gridReduce has template flags for thread predicates. In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. 
- ParallelTypeBitmap thread_predicate_; }; class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { public: + using GroupedWelfordOp::GroupedWelfordOp; + // input, output and init vals are vectors of triplets GroupedGridWelford( IrBuilderPasskey passkey, @@ -835,94 +830,117 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { Val* buffer_stride, bool is_allreduce = false); + NVFUSER_DECLARE_CLONE_AND_CREATE + + int numGroupedWelfordOpAttr() const { + return 1 + outputs().size(); + } + virtual const char* getOpString() const override { return "GroupedGridWelford"; } - Expr* shallowCopy() const override; - - const std::array, 3>& reduction_buffers() const { - return reduction_buffers_; + std::array, 3> reduction_buffers() const { + auto offset = numGroupedWelfordOpAttr() + 5; + auto size = outputs().size() / 3; + std::array, 3> result; + result[0].reserve(size); + result[1].reserve(size); + result[2].reserve(size); + for (auto i : c10::irange(size)) { + result[0].emplace_back(attribute(offset + i * 3)->as()); + result[1].emplace_back(attribute(offset + i * 3 + 1)->as()); + result[2].emplace_back(attribute(offset + i * 3 + 2)->as()); + } + return result; } Allocate* sync_buffer() const { - return sync_buffer_; + return attribute(numGroupedWelfordOpAttr())->as(); } // Which instance of entering this grid reduction is this iteration? Val* entrance_index() const { - return entrance_index_; + return attributeVal(numGroupedWelfordOpAttr() + 1); } // How many times will this grid reduction be entered Val* entrances() const { - return entrances_; + return attributeVal(numGroupedWelfordOpAttr() + 2); } + // Stride of reduction buffers Val* buffer_stride() const { - return buffer_stride_; + return attributeVal(numGroupedWelfordOpAttr() + 3); } + // gridReduce has template flags for thread predicates. In order to + // use them, the thread predicate is held here separately from + // Expr::predicate_. const ParallelTypeBitmap& threadPredicate() const { - return thread_predicate_; + return attribute(numGroupedWelfordOpAttr() + 4) + ->as>() + ->value; + } + ParallelTypeBitmap& threadPredicate() { + return attribute(numGroupedWelfordOpAttr() + 4) + ->as>() + ->value; } GroupedGridWelford* withThreadPredicate( const ParallelTypeBitmap& thread_predicate) { auto result = shallowCopy()->as(); - result->thread_predicate_ = thread_predicate; + result->threadPredicate() = thread_predicate; return result; } - - private: - std::array, 3> reduction_buffers_; - Allocate* sync_buffer_ = nullptr; - // gridReduce has template flags for thread predicates. In order to - // use them, the thread predicate is held here separately from - // Expr::predicate_. - ParallelTypeBitmap thread_predicate_; - Val* entrance_index_ = nullptr; - Val* entrances_ = nullptr; - // Stride of reduction buffers - Val* buffer_stride_ = nullptr; }; // Allocate an instance of the fused reduction class. 
class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { + explicit AllocateFusedReduction(IrBuilderPasskey passkey, Expr* grid_expr); + public: + using Expr::Expr; + explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GridReduction* grid_reduction); + GridReduction* grid_reduction) + : AllocateFusedReduction(passkey, dynamic_cast(grid_reduction)) {} explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GridWelford* grid_welford); + GridWelford* grid_welford) + : AllocateFusedReduction(passkey, dynamic_cast(grid_welford)) {} explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GroupedGridReduction* grouped_grid_reduction); + GroupedGridReduction* grouped_grid_reduction) + : AllocateFusedReduction( + passkey, + dynamic_cast(grouped_grid_reduction)) {} explicit AllocateFusedReduction( IrBuilderPasskey passkey, - GroupedGridWelford* grouped_grid_welford); + GroupedGridWelford* grouped_grid_welford) + : AllocateFusedReduction( + passkey, + dynamic_cast(grouped_grid_welford)) {} + + NVFUSER_DECLARE_CLONE_AND_CREATE virtual const char* getOpString() const override { return "AllocateFusedReduction"; } - Expr* shallowCopy() const override; - + //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford Expr* gridExpr() const { - return grid_expr_; + return attribute(0)->asExpr(); } TensorIndex* out() const; const ParallelTypeBitmap& threadPredicate() const; - - private: - //! GridReduction, GridWelford, GroupedGridReduction or GroupedGridWelford - Expr* grid_expr_ = nullptr; }; } // namespace kir diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 430cb31f7774..fdea5857fb4a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -856,9 +856,12 @@ void IndexLowering::handle(const GroupedWelfordOp* grouped_wop) { std::vector indexed_outputs(grouped_wop->numExprs()); std::vector indexed_inputs(grouped_wop->numExprs()); + auto output_vals = grouped_wop->outputVals(); + auto input_vals = grouped_wop->inputVals(); + for (const auto i : c10::irange(grouped_wop->numExprs())) { - const auto& output = grouped_wop->outputVals().at(i); - const auto& input = grouped_wop->inputVals().at(i); + const auto& output = output_vals.at(i); + const auto& input = input_vals.at(i); WelfordTriplet indexed_output; WelfordTriplet indexed_input; for (const auto j : c10::irange(3)) { diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 9fb69c7ae647..c490ef12daed 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -236,7 +236,7 @@ void OptOutMutator::mutate(RNGOp* rop) { Val* out = maybeMutated(rop->output(0)); Val* philox_idx = maybeMutated(rop->getPhiloxIndex()); - auto& parameters = rop->getParameters(); + auto parameters = rop->getParameters(); std::vector mutated_parameters; bool all_mutated_same = true; for (auto v : parameters) { diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 93c24a2f5068..1f51e4c36b67 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -140,6 +140,8 @@ TensorView::TensorView( "Function invalid for kernel container."); } +NVFUSER_DEFINE_CLONE(TensorView) + void TensorView::convertRfactorToRootDomain() { // For a given TensorView, does its domain (root / rfactor) contain any // concrete sized extents? 
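Illustrative sketch (not part of the patch): a toy model of why the hand-written shallowCopy overrides deleted throughout this diff can be replaced by one macro per class. If each expression exposes a virtual factory that rebuilds an instance from the shared input/output lists, the base class can implement shallow copying once. The names (ToyExpr, ToyAdd, the toy newObject signature) are assumptions of this sketch, not the actual macro expansion.

// Toy model of "generic shallow copy via a per-class virtual factory".
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct ToyExpr {
  std::vector<int> inputs;
  std::vector<int> outputs;

  virtual ~ToyExpr() = default;

  // Roughly the role a clone-and-create macro plays per class (an assumption
  // of this sketch; the real macro definitions live elsewhere in the tree).
  virtual std::unique_ptr<ToyExpr> newObject(std::vector<int> in,
                                             std::vector<int> out) const = 0;

  // One generic implementation instead of one hand-written copy per subclass.
  std::unique_ptr<ToyExpr> shallowCopy() const {
    return newObject(inputs, outputs);
  }
};

struct ToyAdd : ToyExpr {
  ToyAdd(std::vector<int> in, std::vector<int> out) {
    inputs = std::move(in);
    outputs = std::move(out);
  }
  std::unique_ptr<ToyExpr> newObject(std::vector<int> in,
                                     std::vector<int> out) const override {
    return std::make_unique<ToyAdd>(std::move(in), std::move(out));
  }
};

int main() {
  ToyAdd add({1, 2}, {3});
  auto copy = add.shallowCopy();
  assert(copy->inputs == add.inputs && copy->outputs == add.outputs);
  return 0;
}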
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp
index ab7aa93f5130..eeefdb6ac750 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp
@@ -758,6 +758,7 @@ TEST_F(NVFuserTest, FusionRegister_CUDA) {
 
 // dummy expr with 2 outputs only for toposort test.
 struct DummyExpr : public Expr {
+  using Expr::Expr;
   ~DummyExpr() = default;
   DummyExpr(
       IrBuilderPasskey passkey,
@@ -765,13 +766,13 @@ struct DummyExpr : public Expr {
       Val* _outrhs,
       Val* _lhs,
       Val* _rhs)
-      : Expr(passkey) // terribly safe :-D
-  {
+      : Expr(passkey) {
     addOutput(_outlhs);
     addOutput(_outrhs);
     addInput(_lhs);
     addInput(_rhs);
   }
+  NVFUSER_DECLARE_CLONE_AND_CREATE
   DummyExpr(const DummyExpr& other) = delete;
   DummyExpr& operator=(const DummyExpr& other) = delete;
   DummyExpr(DummyExpr&& other) = delete;
@@ -779,11 +780,10 @@ struct DummyExpr : public Expr {
   virtual const char* getOpString() const override {
     return "DummyExpr";
   }
-  Expr* shallowCopy() const override {
-    return nullptr;
-  }
 };
 
+NVFUSER_DEFINE_CLONE_AND_CREATE(DummyExpr)
+
 TEST_F(NVFuserTest, FusionTopoSort_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 4cac24608647..ecff17dc5fee 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -33,7 +33,8 @@ enum class ValType {
   Scalar,
   NamedScalar,
   Predicate,
-  TensorIndex
+  TensorIndex,
+  Attribute
 };
 
 // Manual - The user provides the Bool value. Predicate generation is bypassed.
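Illustrative sketch (not part of the patch): a toy model of the attribute-count guard used by GridReduction, GroupedGridReduction and GroupedGridWelford above. Because those derived expressions address attributes by position after the ones appended by their parent class, they assert the parent's attribute count before adding their own. All names below are invented for this sketch.

// Toy model of "assert the parent's attribute count before appending".
#include <cassert>
#include <cstddef>
#include <vector>

struct ToyParent {
  std::vector<int> attributes;
  ToyParent() {
    attributes.push_back(7);  // parent-owned attribute #0
    attributes.push_back(8);  // parent-owned attribute #1
  }
};

struct ToyChild : ToyParent {
  static constexpr std::size_t kParentAttrCount = 2;
  ToyChild() {
    // Mirrors the TORCH_INTERNAL_ASSERT(attributes().size() == ..., ...)
    // checks in the patch: fail loudly if the parent's attribute list grows
    // or shrinks, since that would silently shift every child offset.
    assert(attributes.size() == kParentAttrCount);
    attributes.push_back(42);  // child-owned attribute
  }
  int childAttr() const { return attributes.at(kParentAttrCount); }
};

int main() {
  ToyChild c;
  assert(c.childAttr() == 42);
  return 0;
}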