diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 40bb327ba323c..c217ba8b2c74f 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -553,6 +553,23 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
   return out;
 }
 
+TensorView* eye(Val* rows, Val* cols, DataType dtype) {
+  TORCH_CHECK(rows->getDataType() == DataType::Int, "rows must have type Int");
+  TORCH_CHECK(cols->getDataType() == DataType::Int, "cols must have type Int");
+  auto out = TensorViewBuilder()
+                 .ndims(2)
+                 .dtype(dtype)
+                 .contiguity({true, true})
+                 .shape(std::vector<Val*>{rows, cols})
+                 .build();
+  IrBuilder::create<EyeOp>(out, dtype);
+  return out;
+}
+
+TensorView* eye(Val* size, DataType dtype) {
+  return eye(size, size, dtype);
+}
+
 // UNARY OPERATIONS
 
 #define NVFUSER_DEFINE_UNARY_OP(op_name, op_type) \
diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h
index 548e50406031c..8b6702ab1d372 100644
--- a/torch/csrc/jit/codegen/cuda/arith.h
+++ b/torch/csrc/jit/codegen/cuda/arith.h
@@ -156,6 +156,8 @@ TORCH_CUDA_CU_API TensorView* arange(
     Val* end,
     Val* step,
     DataType dtype = DataType::Int);
+TORCH_CUDA_CU_API TensorView* eye(Val* size, DataType dtype);
+TORCH_CUDA_CU_API TensorView* eye(Val* rows, Val* cols, DataType dtype);
 
 // UNARY OPERATIONS
 // abs
diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index da4ca46efa543..8f21dac5bee8b 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -566,12 +566,20 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }
 
   void handle(const ARangeOp* aop) final {
-    auto index = genTensorIndex(aop->getLinearIndex()->as<kir::TensorIndex>());
+    auto index =
+        genTensorIndex(aop->getLinearLogicalIndex()->as<kir::TensorIndex>());
     indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">";
     code_ << "(" << index << ", " << gen(aop->start()) << ", "
           << gen(aop->step()) << ");\n";
   }
 
+  void handle(const EyeOp* aop) final {
+    auto index1 = gen(aop->getIndex1());
+    auto index2 = gen(aop->getIndex2());
+    indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")";
+    code_ << "(" << index1 << " == " << index2 << ");\n";
+  }
+
   void handle(const UnaryOp* uop) final {
     bool is_vector_op = false;
     size_t vector_word_size = 1;
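Note on the EyeOp handler above: eye needs no data operand at all; each output element is just the comparison of its two logical indices, cast to the output dtype. As a rough sketch of the shape of the emitted statement (the names T0, i82, i83 and idx are hypothetical placeholders, not what the generator literally prints):

    // Emitted per-element statement for an EyeOp with Float output:
    // compare row index against column index, cast bool to float.
    T0[idx] = (float)(i82 == i83);  // 1.0f on the diagonal, 0.0f elsewhere
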
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp
index 14079f3395060..70e9ae16375e5 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.cpp
+++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -101,6 +101,9 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::ARangeOp:
       ptr(handler)->handle(expr->as<ARangeOp>());
       return;
+    case ExprType::EyeOp:
+      ptr(handler)->handle(expr->as<EyeOp>());
+      return;
     case ExprType::UnaryOp:
       ptr(handler)->handle(expr->as<UnaryOp>());
       return;
@@ -290,6 +293,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::ARangeOp:
       ptr(handler)->handle(expr->as<ARangeOp>());
       return;
+    case ExprType::EyeOp:
+      ptr(handler)->handle(expr->as<EyeOp>());
+      return;
     case ExprType::UnaryOp:
       ptr(handler)->handle(expr->as<UnaryOp>());
       return;
@@ -487,6 +493,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
     case ExprType::ARangeOp:
       ptr(mutator)->mutate(expr->as<ARangeOp>());
       return;
+    case ExprType::EyeOp:
+      ptr(mutator)->mutate(expr->as<EyeOp>());
+      return;
     case ExprType::UnaryOp:
       ptr(mutator)->mutate(expr->as<UnaryOp>());
       return;
@@ -749,6 +758,9 @@ void OptOutConstDispatch::handle(const FullOp* stmt) {
 void OptOutConstDispatch::handle(const ARangeOp* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const EyeOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const UnaryOp* stmt) {
   unhandled(stmt);
 }
@@ -908,6 +920,9 @@ void OptOutDispatch::handle(FullOp* stmt) {
 void OptOutDispatch::handle(ARangeOp* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(EyeOp* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(UnaryOp* stmt) {
   unhandled(stmt);
 }
diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h
index 87e12e3fcddbc..4fea698191ec4 100644
--- a/torch/csrc/jit/codegen/cuda/dispatch.h
+++ b/torch/csrc/jit/codegen/cuda/dispatch.h
@@ -70,6 +70,7 @@ class NamedScalar;
 // Exprs
 class FullOp;
 class ARangeOp;
+class EyeOp;
 class UnaryOp;
 class BinaryOp;
 class TernaryOp;
@@ -147,6 +148,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   // Exprs
   virtual void handle(const FullOp* stmt);
   virtual void handle(const ARangeOp* stmt);
+  virtual void handle(const EyeOp* stmt);
   virtual void handle(const UnaryOp* stmt);
   virtual void handle(const BinaryOp* stmt);
   virtual void handle(const TernaryOp* stmt);
@@ -215,6 +217,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   // Exprs
   virtual void handle(FullOp* stmt);
   virtual void handle(ARangeOp* stmt);
+  virtual void handle(EyeOp* stmt);
   virtual void handle(UnaryOp* stmt);
   virtual void handle(BinaryOp* stmt);
   virtual void handle(TernaryOp* stmt);
@@ -324,6 +327,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   // Exprs
   virtual void mutate(FullOp*);
   virtual void mutate(ARangeOp*);
+  virtual void mutate(EyeOp*);
   virtual void mutate(UnaryOp*);
   virtual void mutate(BinaryOp*);
   virtual void mutate(TernaryOp*);
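A new Expr subclass has to be registered in all three dispatch tables (dispatch, constDispatch, mutatorDispatch) plus the default OptOut* handlers, otherwise the node is unreachable by visitors. A minimal self-contained sketch of the enum-switch double dispatch these tables implement (illustrative types, not the real nvfuser classes):

    enum class ExprType { EyeOp, Other };

    struct EyeOp;

    struct OptOutConstDispatch {
      // Default handler: node types without an override fall through here.
      virtual void handle(const EyeOp*) {}
      virtual ~OptOutConstDispatch() = default;
    };

    struct Expr {
      explicit Expr(ExprType t) : etype(t) {}
      ExprType etype;
      template <typename E>
      const E* as() const {
        return static_cast<const E*>(this);
      }
    };

    struct EyeOp : Expr {
      EyeOp() : Expr(ExprType::EyeOp) {}
    };

    // The enum tag selects the static downcast; the virtual call then picks
    // the concrete visitor override.
    void constDispatch(OptOutConstDispatch* handler, const Expr* expr) {
      switch (expr->etype) {
        case ExprType::EyeOp:
          handler->handle(expr->as<EyeOp>());
          return;
        default:
          return;
      }
    }
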
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp
index 5ad56bda15f21..8d5a499289186 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -1937,52 +1937,55 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
   return strided_inds;
 }
 
-std::vector<Val*> Index::getLinearIndex(
-    TensorView* consumer_tv,
-    const std::vector<kir::ForLoop*>& loops) {
+template <typename func_t>
+auto evaluateWithOverridenContiguity(
+    TensorView* tv,
+    bool contiguity,
+    const func_t& functor) -> decltype(functor()) {
   // Use domain guard to ignore the contiguity of
   // consumer tv.
-  TensorDomain* consumer_tv_no_contiguity_domain = nullptr;
-  auto contiguity_vector =
-      std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), true);
-  if (consumer_tv->hasRFactor()) {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->getRFactorDomain(),
-        consumer_tv->domain()->domain(),
+  TensorDomain* domain_with_specified_contiguity = nullptr;
+  std::vector<bool> contiguity_vector(
+      tv->getMaybeRFactorDomain().size(), contiguity);
+  if (tv->hasRFactor()) {
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(),
+        tv->getRFactorDomain(),
+        tv->domain()->domain(),
         contiguity_vector);
   } else {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->domain()->domain(),
-        contiguity_vector);
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
   }
 
-  ir_utils::TVDomainGuard domain_guard(
-      consumer_tv, consumer_tv_no_contiguity_domain);
+  ir_utils::TVDomainGuard domain_guard(tv, domain_with_specified_contiguity);
 
-  // TODO:
-  //  More optimization on the underlying tensor layout
-  //  will be done in a follow up.
-  return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  return functor();
 }
 
-std::vector<Val*> Index::getGlobalConsumerStridedIndices(
-    const TensorView* consumer_tv,
+std::vector<Val*> Index::getLinearLogicalIndex(
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
-
-  auto gpu_lower = GpuLower::current();
-
-  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  return evaluateWithOverridenContiguity(consumer_tv, true, [&]() {
+    return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  });
+}
 
-  auto consumer_indexing = index_from_id_graph.index;
+std::vector<Val*> Index::getPerDimLogicalIndex(
+    TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  return evaluateWithOverridenContiguity(consumer_tv, false, [&]() {
+    IndexFromIdGraph index_from_id_graph =
+        getTensorIndexFromIdGraph(loops, consumer_tv);
+    return getRootIndices(consumer_tv, loops, index_from_id_graph);
+  });
+}
 
+std::vector<Val*> Index::getStrides(const TensorView* tv) {
   // Indices should now be mapped onto IterDomains in consumer, so just grab
   // and use them.
-  auto root_dom = consumer_tv->getMaybeRFactorDomain();
+  auto root_dom = tv->getMaybeRFactorDomain();
 
-  // TODO: Abstract stride logic to reuse with producer indexing
   std::vector<Val*> strides(
       root_dom.size(), GpuLower::current()->kernel()->oneVal());
   {
@@ -1993,14 +1996,13 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
       continue;
     }
     std::stringstream ss;
-    ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
+    ss << "T" << tv->name() << ".stride[" << stride_i++ << "]";
     strides[i] =
         SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
   }
 }
 
-  TORCH_INTERNAL_ASSERT(
-      root_dom.size() == consumer_tv->domain()->contiguity().size());
+  TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size());
   Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
   for (const auto i : c10::irange(root_dom.size())) {
     auto dim = root_dom.size() - i - 1;
@@ -2008,24 +2010,7 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
       continue;
     }
 
-    Val* root_ind = nullptr;
-    if (consumer_indexing.indexMap().find(root_dom[dim]) !=
-        consumer_indexing.indexMap().end()) {
-      root_ind = consumer_indexing.indexMap().at(root_dom[dim]);
-    } else if (root_dom[dim]->isBroadcast()) {
-      root_ind = GpuLower::current()->kernel()->zeroVal();
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        root_ind != nullptr,
-        "Couldn't find root mapping for ",
-        consumer_tv->toString(),
-        " dim: ",
-        dim,
-        " id: ",
-        root_dom[dim]->toString());
-
-    if (consumer_tv->domain()->contiguity()[dim]) {
+    if (tv->domain()->contiguity()[dim]) {
       // If contig, used the stored stride which may be the previous
       // dimensions stride * previous dimensions size
       strides[dim] = cur_contig_stride;
@@ -2041,12 +2026,18 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
           strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
     }
   }
+  return strides;
+}
 
-  auto vectorize_shift =
-      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+std::vector<Val*> Index::getRootIndices(
+    const TensorView* tv,
+    const std::vector<kir::ForLoop*>& loops,
+    const IndexFromIdGraph& index_from_id_graph) {
+  auto gpu_lower = GpuLower::current();
+  auto root_dom = tv->getMaybeRFactorDomain();
+  auto indexing = index_from_id_graph.index;
 
-  // Global striding
-  std::vector<Val*> strided_inds(
+  std::vector<Val*> root_inds(
       root_dom.size(), GpuLower::current()->kernel()->zeroVal());
   for (const auto i : c10::irange(root_dom.size())) {
     // See a comment in indexing to root domains in getGlobalProducerIndex.
@@ -2057,22 +2048,21 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
     }
 
     TORCH_INTERNAL_ASSERT(
-        consumer_indexing.indexMap().find(root_dom[i]) !=
-            consumer_indexing.indexMap().end(),
+        indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(),
         "Couldn't find root mapping for ",
-        consumer_tv->toString(),
+        tv->toString(),
         " dim: ",
         i,
         " id: ",
         root_dom[i]->toString());
 
-    auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);
+    auto root_ind = indexing.indexMap().at(root_dom[i]);
 
     // index hoist must be done before the adjustments for halo
     root_ind = hoistConsumerIndex(
         root_dom[i],
-        consumer_tv,
-        consumer_indexing,
+        tv,
+        indexing,
         index_from_id_graph.resolved_loop_domains,
         index_from_id_graph.initial_concrete_index_map,
         loops,
@@ -2080,12 +2070,33 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
     root_ind = SimplifyingIrBuilder::addExpr(
         root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i]));
+    root_inds[i] = root_ind;
+  }
+  return root_inds;
+}
 
-    if (root_ind->isZeroInt()) {
+std::vector<Val*> Index::getGlobalConsumerStridedIndices(
+    const TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
+
+  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  auto consumer_indexing = index_from_id_graph.index;
+  auto strides = getStrides(consumer_tv);
+  auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph);
+
+  // Global striding
+  auto vectorize_shift =
+      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+  std::vector<Val*> strided_inds(
+      root_inds.size(), GpuLower::current()->kernel()->zeroVal());
+  for (const auto i : c10::irange(root_inds.size())) {
+    if (root_inds[i]->isZeroInt()) {
       continue;
     } else {
-      auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
-      if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
+      auto strided_ind =
+          SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]);
+      if (i == strides.size() - 1 && vectorize_shift != nullptr) {
         strided_inds[i] =
             SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
       } else {
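The refactoring above separates two views of the same logical position. getLinearLogicalIndex overrides contiguity to true so the per-dimension indices collapse against dense strides into one offset (what arange and rand need), while getPerDimLogicalIndex overrides contiguity to false, so root domains are not merged during indexing and one index per root domain is recovered (what eye needs). A standalone sketch of the two forms for a row-major 2D tensor (hypothetical helper names, assuming a fully dense layout):

    #include <cassert>

    struct PerDim {
      long i; // index into root domain 0
      long j; // index into root domain 1
    };

    // Collapse per-dimension logical indices into a single linear offset,
    // assuming dense row-major strides (contiguity overridden to true).
    long toLinear(PerDim idx, long ncols) {
      return idx.i * ncols + idx.j;
    }

    int main() {
      // Element (2, 3) of a 3x4 tensor sits at linear offset 2*4 + 3.
      assert(toLinear({2, 3}, /*ncols=*/4) == 11);
      return 0;
    }
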
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h
index 43cde710fdfc4..5d8703c2e970e 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/index_compute.h
@@ -62,6 +62,7 @@ namespace cuda {
 
 class ContigIDs;
 class LoopIndexing;
+struct IndexFromIdGraph;
 
 class IndexCompute : public BackwardVisitor {
  protected:
@@ -331,6 +332,15 @@ class Index {
       const TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
+  // get the strides of a tensor used for the index lowering
+  static std::vector<Val*> getStrides(const TensorView* tv);
+
+  // get the root indices of a tensor used for the index lowering
+  static std::vector<Val*> getRootIndices(
+      const TensorView* tv,
+      const std::vector<kir::ForLoop*>& loops,
+      const IndexFromIdGraph& index_from_id_graph);
+
  public:
   // Indexing functions
   // Consumer = Producer
@@ -363,19 +373,28 @@ class Index {
       const TensorView* consumer,
       const std::vector<kir::ForLoop*>& loops);
 
-  //! Returns a vector of strided indices mapped onto the (rfactor)
+  //! Returns the logical index linearized from a multi-dimension address
+  //! into a linear memory address of a consumer tensor. The returned index
+  //! is intended to be used for the computation of some tensor factories,
+  //! such as: arange and rand (for Philox pseudo random sequences)
+  static std::vector<Val*> getLinearLogicalIndex(
+      TensorView* consumer_tv,
+      const std::vector<kir::ForLoop*>& loops);
+
+  //! Returns a vector of logical indices mapped onto the (rfactor)
   //! root domain of a consumer tensor. The returned index is intended
-  //! to be used to index into arange or Philox pseudo random sequences
-  static std::vector<Val*> getLinearIndex(
+  //! to be used for the computation of some tensor factories, such as:
+  //! eye
+  static std::vector<Val*> getPerDimLogicalIndex(
       TensorView* consumer_tv,
       const std::vector<kir::ForLoop*>& loops);
 
   //! Take a consumer tensorview and loop nest and generates predicates
   //! associated with the concrete roots of the loop nest. Returns a list of
-  //! predicates, and a list of concrete roots they're associated with. It is
-  //! assumed that no predicate is required if index[i] is an index directly
-  //! from a for loop. This will not catch all cases if we actually have static
-  //! size information for example:
+  //! predicates, and a list of concrete roots they're associated with. It
+  //! is assumed that no predicate is required if index[i] is an index
+  //! directly from a for loop. This will not catch all cases if we actually
+  //! have static size information for example:
   //!
   //! TV[I].split(4)
   //! would produce the code:
@@ -384,14 +403,14 @@ class Index {
   //!   if( i * 4 + j < TV.size(0))
   //!     TV[i * 4 + j]...
   //!
-  //! However if we had TV.size[0] = 16 at "compile time" then we wouldn't need
-  //! the predicate. This will be caught by canOmitPredicate in the predicate
-  //! lowering
+  //! However if we had TV.size[0] = 16 at "compile time" then we wouldn't
+  //! need the predicate. This will be caught by canOmitPredicate in the
+  //! predicate lowering
   //!
-  //! unswitch_or_vec_loop is the for loop to start the unswitch like predicate,
-  //! this is not a bool value as if we have an unswitch loop with a vectorized
-  //! loop inside, we only want to base the "unswitch" like predicate on the
-  //! vectorized loop.
+  //! unswitch_or_vec_loop is the for loop to start the unswitch like
+  //! predicate, this is not a bool value as if we have an unswitch loop
+  //! with a vectorized loop inside, we only want to base the "unswitch"
+  //! like predicate on the vectorized loop.
   static std::vector<RootPredicateInfo> getReferenceRootPredicates(
       TensorView* consumer_tv,
       const std::vector<kir::ForLoop*>& loops,
diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp
index 00568cbca882e..f0fd438c15672 100644
--- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp
@@ -62,6 +62,7 @@ IR_BUILDER_INSTANTIATE(ViewAsScalar)
 IR_BUILDER_INSTANTIATE(ViewOp)
 IR_BUILDER_INSTANTIATE(FullOp)
 IR_BUILDER_INSTANTIATE(ARangeOp)
+IR_BUILDER_INSTANTIATE(EyeOp)
 IR_BUILDER_INSTANTIATE(UnaryOp)
 IR_BUILDER_INSTANTIATE(BinaryOp)
 IR_BUILDER_INSTANTIATE(TernaryOp)
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
index 12816da31fbb1..489be49ddfc7c 100644
--- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp
@@ -96,6 +96,10 @@ void IrCloner::handle(const ARangeOp* op) {
   clone_ = IrBuilder::clone(op, this);
 }
 
+void IrCloner::handle(const EyeOp* op) {
+  clone_ = IrBuilder::clone(op, this);
+}
+
 void IrCloner::handle(const UnaryOp* op) {
   clone_ = IrBuilder::clone(op, this);
 }
diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h
index a70d4bbd7d15e..06e1ec3359d95 100644
--- a/torch/csrc/jit/codegen/cuda/ir_cloner.h
+++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h
@@ -70,6 +70,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch {
 
   void handle(const FullOp*) override;
   void handle(const ARangeOp*) override;
+  void handle(const EyeOp*) override;
   void handle(const UnaryOp*) override;
   void handle(const BinaryOp*) override;
   void handle(const TernaryOp*) override;
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
index bae4b97795b0c..6c04e4214b07d 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp
@@ -427,6 +427,14 @@ void IrGraphGenerator::handle(const ARangeOp* aop) {
   addArc(aop, aop->output(0));
 }
 
+void IrGraphGenerator::handle(const EyeOp* eop) {
+  // node
+  printExpr(eop, "eye");
+
+  // inputs & outputs
+  addArc(eop, eop->output(0));
+}
+
 void IrGraphGenerator::handle(const UnaryOp* uop) {
   // node
   std::stringstream label;
diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
index e990ccdd7ac5d..1f555ed31ec06 100644
--- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h
+++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h
@@ -84,6 +84,7 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
 
   void handle(const FullOp*) override;
   void handle(const ARangeOp*) override;
+  void handle(const EyeOp*) override;
   void handle(const UnaryOp*) override;
   void handle(const BinaryOp*) override;
   void handle(const TernaryOp*) override;
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index 9c617f5146e63..468bf8f6ce605 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -82,7 +82,7 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr {
     return step_;
   }
 
-  Val* getLinearIndex() const {
+  Val* getLinearLogicalIndex() const {
     return linear_index_;
   }
 
@@ -98,6 +98,63 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr {
   Val* linear_index_ = nullptr;
 };
 
+// Tensor factory for generating identity matrices like
+//
+// [[1, 0, 0],
+//  [0, 1, 0],
+//  [0, 0, 1]]
+//
+// or
+//
+// [[1, 0, 0],
+//  [0, 1, 0],
+//  [0, 0, 1],
+//  [0, 0, 0]]
+//
+// or
+//
+// [[1, 0, 0, 0],
+//  [0, 1, 0, 0],
+//  [0, 0, 1, 0]]
+class TORCH_CUDA_CU_API EyeOp : public Expr {
+ public:
+  EyeOp(
+      IrBuilderPasskey,
+      Val* out,
+      DataType dtype,
+      Val* index1 = nullptr,
+      Val* index2 = nullptr);
+
+  EyeOp(const EyeOp* src, IrCloner* ir_cloner);
+
+  bool sameAs(const Statement* other) const override;
+
+  DataType dtype() const {
+    return dtype_;
+  }
+
+  Val* getIndex1() const {
+    return index1_;
+  }
+
+  void setIndex1(Val* index) {
+    index1_ = index;
+  }
+
+  Val* getIndex2() const {
+    return index2_;
+  }
+
+  void setIndex2(Val* index) {
+    index2_ = index;
+  }
+
+ private:
+  const DataType dtype_;
+  Val* index1_ = nullptr;
+  Val* index2_ = nullptr;
+};
+
 //! A specialization for Unary operations. Unary operations take in a single
 //! input and produce a single output. Examples include:
 //!  1) Casting operation i.e. float(a_val)
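On the constructor contract encoded above: an EyeOp records only the output's row/column extents as inputs (see ir_nodes.cpp below); there is no data operand, so the op's only dependence in the IR graph is on the size scalars. A usage sketch against the arith.cpp API from this patch (assumes an active Fusion/FusionGuard and a `fusion` pointer in scope, as in the tests further down):

    // n x n identity matrix of floats.
    Int* n = IrBuilder::create<Int>();
    fusion->addInput(n);
    TensorView* sq = eye(n, DataType::Float);
    fusion->addOutput(sq);

    // n x m rectangular "identity": ones where row index equals column index.
    Int* m = IrBuilder::create<Int>();
    fusion->addInput(m);
    TensorView* rect = eye(n, m, DataType::Float);
    fusion->addOutput(rect);
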
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
index e9ea7766e473d..4258d8c6b1377 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -302,6 +302,27 @@ void IrPrinter::handle(const ARangeOp* aop) {
   os_ << ";\n";
 }
 
+void IrPrinter::handle(const EyeOp* eop) {
+  if (!print_inline_) {
+    indent();
+    os_ << eop->output(0) << "\n";
+    indent_size_++;
+    indent();
+    os_ << " = ";
+  } else {
+    checkInlineable(eop);
+  }
+
+  os_ << "eye(";
+  handle(eop->input(0));
+  os_ << ", " << eop->dtype() << ")";
+
+  indent_size_--;
+
+  if (!print_inline_)
+    os_ << ";\n";
+}
+
 void IrPrinter::handle(const UnaryOp* uop) {
   bool istvop = ir_utils::isTvOp(uop);
   if (!print_inline_) {
diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h
index d38fd4a8a8cf5..599e50286d294 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.h
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -84,6 +84,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
 
   void handle(const FullOp*) final;
   void handle(const ARangeOp*) final;
+  void handle(const EyeOp*) final;
   void handle(const UnaryOp*) final;
   void handle(const BinaryOp*) final;
   void handle(const TernaryOp*) final;
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index e2b697c85f4bd..0d8d04a89c888 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -242,6 +242,58 @@ ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner)
       step_(ir_cloner->clone(src->step_)),
       linear_index_(ir_cloner->clone(src->linear_index_)) {}
 
+EyeOp::EyeOp(
+    IrBuilderPasskey passkey,
+    Val* out,
+    DataType dtype,
+    Val* index1,
+    Val* index2)
+    : Expr(passkey, ExprType::EyeOp),
+      dtype_(dtype),
+      index1_(index1),
+      index2_(index2) {
+  if (out->isA<TensorView>()) {
+    addInput(out->as<TensorView>()->getRootDomain()[0]->extent());
+    if (out->as<TensorView>()->getRootDomain()[1] !=
+        out->as<TensorView>()->getRootDomain()[0]) {
+      addInput(out->as<TensorView>()->getRootDomain()[1]->extent());
+    }
+  }
+  addOutput(out);
+}
+
+EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner)
+    : Expr(src, ir_cloner),
+      dtype_(src->dtype_),
+      index1_(ir_cloner->clone(src->index1_)),
+      index2_(ir_cloner->clone(src->index2_)) {}
+
+bool EyeOp::sameAs(const Statement* other) const {
+  if (this == other) {
+    return true;
+  }
+  if (!other->isA<EyeOp>()) {
+    return false;
+  }
+  const auto other_op = other->as<EyeOp>();
+  if (dtype_ != other_op->dtype_) {
+    return false;
+  }
+  if ((index1_ == nullptr) != (other_op->index1_ == nullptr)) {
+    return false;
+  }
+  if ((index2_ == nullptr) != (other_op->index2_ == nullptr)) {
+    return false;
+  }
+  if ((index1_ != nullptr) && !index1_->sameAs(other_op->index1_)) {
+    return false;
+  }
+  if ((index2_ != nullptr) && !index2_->sameAs(other_op->index2_)) {
+    return false;
+  }
+  return Expr::sameAs(other);
+}
+
 bool ARangeOp::sameAs(const Statement* other) const {
   if (this == other) {
     return true;
@@ -300,8 +352,9 @@ bool UnaryOp::sameAs(const Statement* other) const {
     return false;
   }
   const auto other_op = other->as<UnaryOp>();
-  if (getUnaryOpType() != other_op->getUnaryOpType())
+  if (getUnaryOpType() != other_op->getUnaryOpType()) {
     return false;
+  }
   return Expr::sameAs(other);
 }
 
@@ -336,8 +389,9 @@ bool BinaryOp::sameAs(const Statement* other) const {
     return false;
   }
   const auto other_op = other->as<BinaryOp>();
-  if (getBinaryOpType() != other_op->getBinaryOpType())
+  if (getBinaryOpType() != other_op->getBinaryOpType()) {
     return false;
+  }
   return Expr::sameAs(other);
 }
 
@@ -376,8 +430,9 @@ bool TernaryOp::sameAs(const Statement* other) const {
     return false;
   }
   const auto other_op = other->as<TernaryOp>();
-  if (getTernaryOpType() != other_op->getTernaryOpType())
+  if (getTernaryOpType() != other_op->getTernaryOpType()) {
     return false;
+  }
   return Expr::sameAs(other);
 }
diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
index 273c2bf912d6c..ff9075c244b67 100644
--- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp
@@ -208,7 +208,18 @@ struct SubstituteInExpr : public OptInDispatch {
         end,
         step,
         arange_expr->dtype(),
-        arange_expr->getLinearIndex());
+        arange_expr->getLinearLogicalIndex());
+  }
+
+  void handle(EyeOp* eye_expr) final {
+    auto out = reference_->sameAs(eye_expr->output(0)) ? substitute_
+                                                       : eye_expr->output(0);
+    expr_ = IrBuilder::create<EyeOp>(
+        eye_expr->container(),
+        out,
+        eye_expr->dtype(),
+        eye_expr->getIndex1(),
+        eye_expr->getIndex2());
   }
 
   void handle(UnaryOp* unary_expr) final {
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp
index 25e22fc2e4b49..4719a5fd7bfdf 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -100,7 +100,7 @@ void IndexLowering::handle(const RNGOp* rop) {
 
   // TensorIndex for philox subsequence and component.
   auto philox_index = SimplifyingIrBuilder::create<kir::TensorIndex>(
-      out_tv, Index::getLinearIndex(out_tv, for_loops_));
+      out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_));
 
   // TensorIndex for writing rand_like output.
   const auto out = lowerDstIndex(out_tv);
@@ -135,11 +135,11 @@ void IndexLowering::handle(const ARangeOp* aop) {
   auto out_tv = dynamic_cast<TensorView*>(aop->output(0));
   TORCH_INTERNAL_ASSERT(out_tv != nullptr);
 
-  // TensorIndex for philox subsequence and component.
+  // linear index for computing arange output
   auto linear_index = SimplifyingIrBuilder::create<kir::TensorIndex>(
-      out_tv, Index::getLinearIndex(out_tv, for_loops_));
+      out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_));
 
-  // TensorIndex for writing rand_like output.
+  // TensorIndex for writing arange output.
   const auto out = lowerDstIndex(out_tv);
   auto lowered = IrBuilder::create<ARangeOp>(
       out, aop->start(), aop->end(), aop->step(), aop->dtype(), linear_index);
@@ -148,6 +148,24 @@ void IndexLowering::handle(const ARangeOp* aop) {
   GpuLower::current()->propagateExprInfo(aop, back());
 }
 
+void IndexLowering::handle(const EyeOp* eop) {
+  auto out_tv = dynamic_cast<TensorView*>(eop->output(0));
+  TORCH_INTERNAL_ASSERT(out_tv != nullptr);
+
+  // per-dim logical indices for computing eye output
+  auto indices = Index::getPerDimLogicalIndex(out_tv, for_loops_);
+  TORCH_INTERNAL_ASSERT(indices.size() == 2);
+  auto index1 = indices[0];
+  auto index2 = indices[1];
+
+  // TensorIndex for writing eye output.
+  const auto out = lowerDstIndex(out_tv);
+  auto lowered = IrBuilder::create<EyeOp>(out, eop->dtype(), index1, index2);
+
+  pushBack(lowered);
+  GpuLower::current()->propagateExprInfo(eop, back());
+}
+
 void IndexLowering::handle(const UnaryOp* uop) {
   const auto in = lowerSrcIndex(uop->in(), uop->out());
   const auto out = lowerDstIndex(uop->out());
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h
index df4f405fd3e60..6c08eeb195ea5 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index.h
+++ b/torch/csrc/jit/codegen/cuda/lower_index.h
@@ -40,6 +40,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
 
   void handle(const FullOp*) final;
   void handle(const ARangeOp*) final;
+  void handle(const EyeOp*) final;
   void handle(const ViewAsScalar*) final;
 
   void handle(const UnaryOp*) final;
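Putting the lowering together: getPerDimLogicalIndex must return exactly two indices for the 2D eye output (asserted above), and those flow into the kernel as index1/index2 of the lowered EyeOp that codegen.cpp turns into the == comparison. Written as plain reference semantics (a sketch for clarity, not the generated kernel):

    // What an eye kernel computes per element, as a host-side loop.
    template <typename T>
    void eye_reference(T* out, long rows, long cols) {
      for (long i = 0; i < rows; ++i) {
        for (long j = 0; j < cols; ++j) {
          out[i * cols + j] = static_cast<T>(i == j);
        }
      }
    }
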
diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
index 2d4444d340903..e7d53ca59a93e 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -101,7 +101,7 @@ struct IndexingParameters {
 };
 
 // Initial loop index map for global producer or consumer case.
-IndexingParameters getGlobalIndexParameters(
+IndexingParameters getLinearIndexParameters(
     const LoopIndexing& loop_indexing,
     bool index_producer = false) {
   IndexingParameters index_parameters;
@@ -797,7 +797,7 @@ IndexFromIdGraph getTensorIndexFromIdGraph(
   }
 
   if (is_global) {
-    index_parameters = getGlobalIndexParameters(loop_indexing, index_producer);
+    index_parameters = getLinearIndexParameters(loop_indexing, index_producer);
   } else {
     index_parameters = getNonGlobalInitialIndexParameters(
         loop_indexing, consumer_tv, index_producer, producer_tv, p2c_map);
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index 5c326397530bb..85d533bccb2e6 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -94,6 +94,7 @@ bool isTvOp(const Expr* expr) {
        expr->getExprType().value() == ExprType::RNGOp ||
        expr->getExprType().value() == ExprType::FullOp ||
        expr->getExprType().value() == ExprType::ARangeOp ||
+       expr->getExprType().value() == ExprType::EyeOp ||
       expr->getExprType().value() == ExprType::ReductionOp ||
        expr->getExprType().value() == ExprType::GroupedReductionOp ||
        expr->getExprType().value() == ExprType::WelfordOp ||
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index 8079d9f5ee8d7..e4f4d4a0e89ac 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -152,7 +152,19 @@ void OptOutMutator::mutate(ARangeOp* aop) {
       aop->end(),
       aop->step(),
       aop->dtype(),
-      aop->getLinearIndex());
+      aop->getLinearLogicalIndex());
+}
+
+void OptOutMutator::mutate(EyeOp* eop) {
+  Val* out = maybeMutated(eop->output(0));
+
+  if (out->sameAs(eop->output(0))) {
+    return;
+  }
+  auto container = eop->container();
+  container->removeExpr(eop);
+  IrBuilder::create<EyeOp>(
+      container, out, eop->dtype(), eop->getIndex1(), eop->getIndex2());
 }
 
 void OptOutMutator::mutate(UnaryOp* uop) {
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
index 2b35692ca43ef..d7a2ce72896d8 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu
@@ -166,6 +166,38 @@ TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA) {
   testValidate(fusion, {out}, {t0}, {ref}, __LINE__, __FILE__);
 }
 
+TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand2_CUDA) {
+  auto dtype = kFloat;
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  Int* size1 = IrBuilder::create<Int>();
+  Int* size2 = IrBuilder::create<Int>();
+  Int* size3 = IrBuilder::create<Int>();
+  Int* size4 = IrBuilder::create<Int>();
+  fusion->addInput(size1);
+  fusion->addInput(size2);
+  fusion->addInput(size3);
+  fusion->addInput(size4);
+  TensorView* tv0 = rand({size1, size2, size3, size4}, DataType::Float);
+  fusion->addOutput(tv0);
+
+  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
+
+  FusionExecutor fe;
+  fe.compileFusion(fusion, {10, 10, 10, 10});
+
+  at::manual_seed(0);
+  auto cg_outputs = fe.runFusion({10, 10, 10, 10});
+  auto out = cg_outputs[0];
+
+  at::manual_seed(0);
+  auto ref = generate_uniform(10000, dtype).view({10, 10, 10, 10});
+
+  testValidate(fusion, {out}, {10, 10, 10, 10}, {ref}, __LINE__, __FILE__);
+}
+
 TEST_F(NVFuserTest, FusionBroadcastingRNG_CUDA) {
   for (auto dtype : {kFloat, kDouble}) {
     std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp
index 76c7b811db783..06e93fcd579e3 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp
@@ -278,6 +278,62 @@ TEST_F(NVFuserTest, FusionStandaloneARange_CUDA) {
   }
 }
 
+TEST_F(NVFuserTest, FusionStandaloneEye_CUDA) {
+  auto sizes = {0, 1, 10, 17, 1024};
+  auto dtypes = {
+      kBool,
+      kFloat,
+      kLong,
+      kDouble,
+      kHalf,
+      kBFloat16,
+      kInt,
+      kComplexFloat,
+      kComplexDouble};
+
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  Val* size = IrBuilder::create<Int>();
+  Val* maybe_m = IrBuilder::create<Int>();
+  fusion->addInput(size);
+  fusion->addInput(maybe_m);
+  for (auto dtype : dtypes) {
+    if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+      continue;
+    }
+    auto out_tv1 = eye(size, aten_to_data_type(dtype));
+    fusion->addOutput(out_tv1);
+    auto out_tv2 = eye(size, maybe_m, aten_to_data_type(dtype));
+    fusion->addOutput(out_tv2);
+  }
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+
+  for (auto size : sizes) {
+    std::vector<at::Tensor> expect;
+    expect.reserve(dtypes.size());
+    for (auto dtype : dtypes) {
+      if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) {
+        continue;
+      }
+      const auto options =
+          at::TensorOptions().dtype(dtype).device(at::kCUDA, 0);
+      expect.emplace_back(at::eye(size, options));
+      expect.emplace_back(at::eye(size, 15, options));
+    }
+    auto cg_outputs = executor_cache.runFusionWithInputs({size, 15});
+
+    testValidate(
+        executor_cache.fusion(),
+        cg_outputs,
+        {size, 15},
+        expect,
+        __LINE__,
+        __FILE__);
+  }
+}
+
 } // namespace jit
 } // namespace torch
 
 #endif // #if defined(USE_CUDA)
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index e5f6e6878aba5..a333eb4d87ee2 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -303,10 +303,12 @@ static const char* predicate_type2string(PredicateType t) {
 
 static const char* expr_type2string(ExprType t) {
   switch (t) {
-    case ExprType::ARangeOp:
-      return "ARangeOp";
     case ExprType::FullOp:
       return "FullOp";
+    case ExprType::ARangeOp:
+      return "ARangeOp";
+    case ExprType::EyeOp:
+      return "EyeOp";
     case ExprType::UnaryOp:
       return "UnaryOp";
     case ExprType::BinaryOp:
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 066e1921df3c1..de4af398820fc 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -106,8 +106,9 @@ TORCH_CUDA_CU_API bool isSupportedTypeByDevice(DataType dtype);
 
 enum class ExprType {
   Invalid,
-  ARangeOp,
   FullOp,
+  ARangeOp,
+  EyeOp,
   UnaryOp,
   BinaryOp,
   TernaryOp,