diff --git a/third_party/nvfuser/benchmark/bert.cpp b/third_party/nvfuser/benchmark/bert.cpp index a9c98d5954a3..b55b235d760c 100644 --- a/third_party/nvfuser/benchmark/bert.cpp +++ b/third_party/nvfuser/benchmark/bert.cpp @@ -349,13 +349,13 @@ static void setupBiasDropoutAddLayernormBwd1(Fusion* fusion, DataType dtype) { TensorView* tv3 = TensorViewBuilder() .ndims(3) .dtype(dtype) - .contiguity({true, true}) + .contiguity({true, true, c10::nullopt}) .shape({-1, -1, 1}) .build(); TensorView* tv4 = TensorViewBuilder() .ndims(3) .dtype(dtype) - .contiguity({true, true}) + .contiguity({true, true, c10::nullopt}) .shape({-1, -1, 1}) .build(); @@ -457,7 +457,7 @@ static void setupBiasDropoutAddLayernormBwd2(Fusion* fusion, DataType dtype) { TensorView* tv4 = TensorViewBuilder() .ndims(3) .dtype(dtype) - .contiguity({true, true}) + .contiguity({true, true, c10::nullopt}) .shape({-1, -1, 1}) .build(); TensorView* tv5 = makeContigTensor(1, dtype); diff --git a/third_party/nvfuser/benchmark/layer_norm_backward.cpp b/third_party/nvfuser/benchmark/layer_norm_backward.cpp index 3bb5e9d966a9..d8b3a1e492be 100644 --- a/third_party/nvfuser/benchmark/layer_norm_backward.cpp +++ b/third_party/nvfuser/benchmark/layer_norm_backward.cpp @@ -27,12 +27,12 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) { auto bias = makeContigTensor(1, dtype); auto mean = TensorViewBuilder() - .contiguity({false}) + .contiguity({false, c10::nullopt}) .shape({-1, 1}) .dtype(DataType::Float) .build(); auto rstd = TensorViewBuilder() - .contiguity({false}) + .contiguity({false, c10::nullopt}) .shape({-1, 1}) .dtype(DataType::Float) .build(); diff --git a/third_party/nvfuser/benchmark/rms_norm_backward.cpp b/third_party/nvfuser/benchmark/rms_norm_backward.cpp index 9713b4c265c6..ed988a44eedd 100644 --- a/third_party/nvfuser/benchmark/rms_norm_backward.cpp +++ b/third_party/nvfuser/benchmark/rms_norm_backward.cpp @@ -27,7 +27,7 @@ static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) { auto input = makeContigTensor(3, dtype); auto weight = makeContigTensor(1, dtype); auto rstd = TensorViewBuilder() - .contiguity({false, false}) + .contiguity({false, false, c10::nullopt}) .shape({-1, -1, 1}) .dtype(dtype) .build(); diff --git a/third_party/nvfuser/benchmark/scale_bias_relu.cpp b/third_party/nvfuser/benchmark/scale_bias_relu.cpp index f9ed37ccbba2..9d98b6caded4 100644 --- a/third_party/nvfuser/benchmark/scale_bias_relu.cpp +++ b/third_party/nvfuser/benchmark/scale_bias_relu.cpp @@ -20,19 +20,17 @@ static void setupSBR(Fusion* fusion, DataType dtype) { std::vector bcast_shape(kNumberOfDims, 1); bcast_shape[bcast_shape.size() - 1] = -1; - std::vector bcast_contig(1, true); - auto x = makeContigTensor(kNumberOfDims, dtype); auto scale = TensorViewBuilder() - .contiguity(bcast_contig) .shape(bcast_shape) + .contiguity(true) .dtype(dtype) .build(); auto bias = TensorViewBuilder() - .contiguity(bcast_contig) .shape(bcast_shape) + .contiguity(true) .dtype(dtype) .build(); diff --git a/third_party/nvfuser/benchmark/timm.cpp b/third_party/nvfuser/benchmark/timm.cpp index 2a9a5a94de3e..6586fb8c76ae 100644 --- a/third_party/nvfuser/benchmark/timm.cpp +++ b/third_party/nvfuser/benchmark/timm.cpp @@ -16,12 +16,12 @@ static void setup_vit_base_patch16_224_bcast7(Fusion* fusion, void* null) { auto t3 = TensorViewBuilder() .shape({-1, -1, 1}) .dtype(DataType::Float) - .contiguity({true, true}) + .contiguity({true, true, c10::nullopt}) .build(); auto t4 = TensorViewBuilder() .shape({-1, -1, 1}) .dtype(DataType::Float) - 
.contiguity({true, true}) + .contiguity({true, true, c10::nullopt}) .build(); auto t7 = makeContigTensor(3, DataType::Half); @@ -538,14 +538,14 @@ static void setup_vit_base_patch16_224_LN_BWD(Fusion* fusion, void* null) { auto t5 = TensorViewBuilder() .shape({-1, -1, 1}) .dtype(DataType::Float) - .contiguity({true, true}) + .contiguity({true, true, c10::nullopt}) .build(); fusion->addInput(t5); auto t6 = TensorViewBuilder() .shape({-1, -1, 1}) .dtype(DataType::Float) - .contiguity({true, true}) + .contiguity({true, true, c10::nullopt}) .build(); fusion->addInput(t6); diff --git a/third_party/nvfuser/benchmark/utils.cpp b/third_party/nvfuser/benchmark/utils.cpp index a85e1707c868..f496b92edcbc 100644 --- a/third_party/nvfuser/benchmark/utils.cpp +++ b/third_party/nvfuser/benchmark/utils.cpp @@ -145,11 +145,7 @@ TensorView* makeSymbolicTensor(size_t ndims, DataType dtype) { } TensorView* makeContigTensor(size_t ndims, DataType dtype) { - return TensorViewBuilder() - .ndims(ndims) - .dtype(dtype) - .contiguity(std::vector(ndims, true)) - .build(); + return TensorViewBuilder().ndims(ndims).dtype(dtype).contiguity(true).build(); } TensorView* makeConcreteTensor(std::vector shape, DataType dtype) { @@ -159,18 +155,7 @@ TensorView* makeConcreteTensor(std::vector shape, DataType dtype) { TensorView* makeContigConcreteTensor( std::vector shape, DataType dtype) { - std::vector contiguity; - for (auto s : shape) { - if (s == 1) { - continue; - } - contiguity.push_back(true); - } - return TensorViewBuilder() - .shape(shape) - .dtype(dtype) - .contiguity(contiguity) - .build(); + return TensorViewBuilder().shape(shape).dtype(dtype).contiguity(true).build(); } void runBenchmarkIterations( diff --git a/third_party/nvfuser/csrc/contiguity.cpp b/third_party/nvfuser/csrc/contiguity.cpp index 1f44ea6e88db..de03fc2c2d01 100644 --- a/third_party/nvfuser/csrc/contiguity.cpp +++ b/third_party/nvfuser/csrc/contiguity.cpp @@ -386,7 +386,7 @@ NonDivisibleSplitDependencies::NonDivisibleSplitDependencies( ContigIDs::ContigIDs( const std::vector& ids, const std::vector& root_domain, - const std::vector& root_contiguity, + const std::vector>& root_contiguity, const std::unordered_set& final_ids, const std::unordered_map& index_map, const std::unordered_set& divisible_splits, @@ -419,7 +419,7 @@ ContigIDs::ContigIDs( ContigIDs::ContigIDs( const std::vector& ids, const std::vector& root_domain, - const std::vector& root_contiguity, + const std::vector>& root_contiguity, const std::unordered_set& final_ids, const std::unordered_map& index_map, const std::unordered_set& divisible_splits, @@ -458,17 +458,16 @@ void ContigIDs::build(const std::vector& ids) { } TORCH_INTERNAL_ASSERT( - TensorDomain::noBroadcasts(root_domain_).size() == - root_contiguity_.size(), + root_domain_.size() == root_contiguity_.size(), "Arguments don't match ", - TensorDomain::noBroadcasts(root_domain_).size(), + root_domain_.size(), " != ", root_contiguity_.size()); - int no_broadcast_i = 0; for (const auto root_domain_i : c10::irange(root_domain_.size())) { auto root_domain_id = root_domain_.at(root_domain_i)->as(); if (root_domain_id->isBroadcast()) { + TORCH_INTERNAL_ASSERT(!root_contiguity_.at(root_domain_i).has_value()); continue; } root_to_indexed_id_[root_domain_id] = root_domain_id; @@ -479,14 +478,13 @@ void ContigIDs::build(const std::vector& ids) { // rfactor root domains, which should just return "zero" // RootAxisInfo. This should be safe as no rfactor tensor should // need halo. 
- if (root_contiguity_.at(no_broadcast_i) && + if (*root_contiguity_.at(root_domain_i) && !halo_info_->getRootAxisInfo(root_domain_id).hasHalo() && root_domain_id->getIterType() != IterType::GatherScatter) { contig_ids_.emplace(root_domain_id); is_contig_root_.at(root_domain_id) = true; within_contig_ids_[root_domain_id] = std::unordered_set(); } - no_broadcast_i++; } if (!contig_ids_.empty()) { @@ -540,10 +538,10 @@ void ContigIDs::handle(Merge* merge) { bool is_indexing_pass = !ignore_consistent_ordering_; IterDomain* last_root = nullptr; - int no_broadcast_i = 0; for (auto root_id_i : c10::irange(root_domain_.size())) { auto root_id = root_domain_[root_id_i]; if (root_id->isBroadcast()) { + TORCH_INTERNAL_ASSERT(!root_contiguity_.at(root_id_i).has_value()); continue; } if (root_ids.has(root_id)) { @@ -556,14 +554,13 @@ void ContigIDs::handle(Merge* merge) { // If we're computing predicates (ignore_consistent_ordering_==true), // then we don't have this same constraint, we can just ignore // contiguity of the roots all together. - if (!root_contiguity_.at(no_broadcast_i) && is_indexing_pass) { + if (!*root_contiguity_.at(root_id_i) && is_indexing_pass) { if (!root_ids.empty()) { return; } } last_root = root_id; } - no_broadcast_i++; } // If there's a non_divisible split in the history of merge->out then it can't diff --git a/third_party/nvfuser/csrc/contiguity.h b/third_party/nvfuser/csrc/contiguity.h index 88fa1176dba8..fe42b61327e1 100644 --- a/third_party/nvfuser/csrc/contiguity.h +++ b/third_party/nvfuser/csrc/contiguity.h @@ -157,7 +157,7 @@ class ContigIDs : public OptInDispatch { ContigIDs( const std::vector& ids, const std::vector& root_domain, - const std::vector& root_contiguity, + const std::vector>& root_contiguity, const std::unordered_set& final_ids, const std::unordered_map& index_map, const std::unordered_set& divisible_splits, @@ -188,7 +188,7 @@ class ContigIDs : public OptInDispatch { ContigIDs( const std::vector& ids, const std::vector& root_domain, - const std::vector& root_contiguity, + const std::vector>& root_contiguity, const std::unordered_set& final_ids, const std::unordered_map& index_map, const std::unordered_set& divisible_splits, @@ -264,7 +264,7 @@ class ContigIDs : public OptInDispatch { //! Root domains to analyze contiguity const std::vector& root_domain_; //! Contiguity of root_domain_ - const std::vector& root_contiguity_; + const std::vector>& root_contiguity_; //! Domains where indexing/predicates cannot be done with their //! 
consumers domains const std::unordered_set& final_ids_; diff --git a/third_party/nvfuser/csrc/executor_kernel_arg.h b/third_party/nvfuser/csrc/executor_kernel_arg.h index a8a9c75bfbe7..d524ef29f20d 100644 --- a/third_party/nvfuser/csrc/executor_kernel_arg.h +++ b/third_party/nvfuser/csrc/executor_kernel_arg.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace nvfuser { diff --git a/third_party/nvfuser/csrc/fusion_segmenter.cpp b/third_party/nvfuser/csrc/fusion_segmenter.cpp index 33f691bc2efc..f3ae19ef7cd1 100644 --- a/third_party/nvfuser/csrc/fusion_segmenter.cpp +++ b/third_party/nvfuser/csrc/fusion_segmenter.cpp @@ -780,7 +780,7 @@ TensorView* castIntermediateValueInCompleteFusion( return IrBuilder::create( IrBuilder::create( new_root_domain, - TensorDomain::getContiguousContiguity(new_root_domain)), + TensorDomain::getContiguityFilledWith(new_root_domain, true)), data_type); }; diff --git a/third_party/nvfuser/csrc/index_compute.cpp b/third_party/nvfuser/csrc/index_compute.cpp index 2fe61630ecc1..531243e92730 100644 --- a/third_party/nvfuser/csrc/index_compute.cpp +++ b/third_party/nvfuser/csrc/index_compute.cpp @@ -1470,11 +1470,8 @@ std::vector Index::getGlobalProducerStridedIndices( } } - auto no_broadcast_root_dom = TensorDomain::noBroadcasts(root_dom); TORCH_INTERNAL_ASSERT( - no_broadcast_root_dom.size() == - producer_tv->domain()->contiguity().size()); - auto full2nob_map = ir_utils::fullToNoBroadcastMap(root_dom); + root_dom.size() == producer_tv->domain()->contiguity().size()); Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); for (const auto i : c10::irange(root_dom.size())) { auto dim = root_dom.size() - i - 1; @@ -1484,7 +1481,9 @@ std::vector Index::getGlobalProducerStridedIndices( if (root_dom[dim]->isBroadcast()) { strides[dim] = cur_contig_stride->fusion()->zeroVal(); - } else if (producer_tv->domain()->contiguity().at(full2nob_map.at(dim))) { + TORCH_INTERNAL_ASSERT( + !producer_tv->domain()->contiguity().at(dim).has_value()); + } else if (*producer_tv->domain()->contiguity().at(dim)) { // If contig, used the stored stride which may be the previous // dimensions stride * previous dimensions size strides[dim] = cur_contig_stride; @@ -1881,10 +1880,7 @@ std::vector Index::getStrides(const TensorView* tv) { } } - auto no_broadcast_root_dom = TensorDomain::noBroadcasts(root_dom); - TORCH_INTERNAL_ASSERT( - no_broadcast_root_dom.size() == tv->domain()->contiguity().size()); - auto full2nob_map = ir_utils::fullToNoBroadcastMap(root_dom); + TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size()); Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); for (const auto i : c10::irange(root_dom.size())) { auto dim = root_dom.size() - i - 1; @@ -1894,7 +1890,8 @@ std::vector Index::getStrides(const TensorView* tv) { if (root_dom[dim]->isBroadcast()) { strides[dim] = cur_contig_stride->fusion()->zeroVal(); - } else if (tv->domain()->contiguity().at(full2nob_map.at(dim))) { + TORCH_INTERNAL_ASSERT(!tv->domain()->contiguity().at(dim).has_value()); + } else if (*tv->domain()->contiguity().at(dim)) { // If contig, used the stored stride which may be the previous // dimensions stride * previous dimensions size strides[dim] = cur_contig_stride; @@ -2312,12 +2309,8 @@ std::vector getPredicateContigIds( } std::unordered_set final_ids; - int no_broadcast_count = 0; for (auto root_i : c10::irange(consumer_root_domain.size())) { auto root_id = consumer_root_domain[root_i]; - if (!root_id->isBroadcast()) { - no_broadcast_count++; - } 
if (root_id->maybePartial()) { final_ids.insert(root_id); continue; @@ -2335,7 +2328,7 @@ std::vector getPredicateContigIds( ContigIDs contig_finder( consumer_tv->domain()->domain(), consumer_root_domain, - std::vector(no_broadcast_count, true), + TensorDomain::getContiguityFilledWith(consumer_root_domain, true), final_ids, concrete_index_map, GpuLower::current()->divisibleSplitSet(), diff --git a/third_party/nvfuser/csrc/ir_interface_nodes.h b/third_party/nvfuser/csrc/ir_interface_nodes.h index ce21222a698e..b1b94655d89b 100644 --- a/third_party/nvfuser/csrc/ir_interface_nodes.h +++ b/third_party/nvfuser/csrc/ir_interface_nodes.h @@ -243,12 +243,13 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! expressions that use this TensorView are also updated. void convertRfactorToRootDomain(); - void setContiguity(const std::vector& contig) { + void setContiguity(const std::vector>& contig) { domain()->setContiguity(contig); } void setContiguity(bool contig) { - setContiguity(std::vector(domain()->contiguity().size(), contig)); + setContiguity( + TensorDomain::getContiguityFilledWith(getMaybeRFactorDomain(), contig)); } bool hasReduction() const; @@ -640,7 +641,8 @@ class TORCH_CUDA_CU_API TensorViewBuilder { TensorViewBuilder& dtype(DataType dtype); //! Set the contiguity information (default non-contiguous) - TensorViewBuilder& contiguity(std::vector contiguity); + TensorViewBuilder& contiguity(std::vector> contiguity); + TensorViewBuilder& contiguity(bool contiguity); //! Set the shape (default 0 dimensional, ie. scalar) TensorViewBuilder& shape(std::vector shape); @@ -655,7 +657,19 @@ class TORCH_CUDA_CU_API TensorViewBuilder { private: size_t ndims_ = 0; DataType dtype_ = DataType::Float; - std::vector contiguity_; + + // contiguity_ is the vector that you will pass to the constructor of + // TensorDomain. However, constructing this vector can be non-trivial, because + // it is required to be nullopt for broadcast dimensions. We often want to + // create contiguity vector that represents all contiguous or all + // discontiguous. uniform_contiguity_ is there to make this use case more + // convenient. If set, then TensorViewBuilder will automatically fill the + // contiguity with the value of uniform_contiguity_ where it is not required + // to be nullopt. Note that you can only set one of contiguity_ or + // uniform_contiguity_. 
+ std::vector> contiguity_; + c10::optional uniform_contiguity_ = c10::nullopt; + std::vector shape_; std::vector expanded_; }; diff --git a/third_party/nvfuser/csrc/ir_internal_nodes.h b/third_party/nvfuser/csrc/ir_internal_nodes.h index 09bb096460c8..9710e909bd25 100644 --- a/third_party/nvfuser/csrc/ir_internal_nodes.h +++ b/third_party/nvfuser/csrc/ir_internal_nodes.h @@ -1725,20 +1725,20 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { explicit TensorDomain( IrBuilderPasskey, std::vector root_domain, - std::vector contiguity = std::vector()); + std::vector> contiguity = {}); TensorDomain( IrBuilderPasskey, std::vector root_domain, std::vector domain, - std::vector contiguity = std::vector()); + std::vector> contiguity = {}); TensorDomain( IrBuilderPasskey, std::vector root_domain, std::vector rfactor_domain, std::vector domain, - std::vector contiguity = std::vector()); + std::vector> contiguity = {}); TensorDomain(const TensorDomain* src, IrCloner* ir_cloner); @@ -1768,22 +1768,29 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { } // Note: [Contiguity] - // Contiguity is a bool vector which has the same number of elements as - // noBroadcasts(rfactor_domain_). The contiguity of a broadcast dimension is - // not defined. The contiguity of a non-broadcasting dimension is true if and - // only if it is memory dense with the next non-broadcasting dimension. + // Contiguity is a vector of optional which has the same number of + // elements as rfactor_domain_. The contiguity of a broadcast dimension is + // meaningless, so it has to be nullopt. The contiguity of a non-broadcasting + // dimension is true if and only if it is memory dense with the next + // non-broadcasting dimension. // For example, if I have a tensor torch.zeros(4, 1, 3).expand(-1, 10, -1), - // the contiguity will be (true, true), which means 4 is memory dense with 3. - const std::vector& contiguity() const { + // the contiguity will be (true, nullopt, true), which means 4 is memory dense + // with 3. + const std::vector>& contiguity() const { return contiguity_; } - void setContiguity(const std::vector& contig); + void setContiguity(const std::vector>& contig); std::string getContiguityString() const { std::stringstream ss; + bool first = true; for (auto b : contiguity()) { - ss << (b ? "t" : "f"); + if (!first) { + ss << " "; + } + first = false; + ss << (b.has_value() ? (*b ? "t" : "f") : "n"); } return ss.str(); } @@ -1884,10 +1891,12 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { static bool hasBroadcast(const std::vector&); static bool hasReduction(const std::vector&); - // Get a vector whose size is the number of non-broadcast IDs in the given - // rfactor_domain filled with true. - static std::vector getContiguousContiguity( - const std::vector& rfactor_domain); + // Get a vector whose size is the number of IDs in the given rfactor_domain + // filled with fill_value or nullopt depending on whether its corresponding ID + // is broadcast. 
+ static std::vector> getContiguityFilledWith( + const std::vector& rfactor_domain, + bool fill_value); // pair is in order where second is the consumer of first std::pair rFactor(const std::vector& axes); @@ -1898,7 +1907,7 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { std::vector no_bcast_domain_; std::vector no_reduction_domain_; const std::vector rfactor_domain_; - std::vector contiguity_; + std::vector> contiguity_; bool has_reduction_; }; diff --git a/third_party/nvfuser/csrc/ir_iostream.cpp b/third_party/nvfuser/csrc/ir_iostream.cpp index 4a77553aac9a..e9e7b51a33f3 100644 --- a/third_party/nvfuser/csrc/ir_iostream.cpp +++ b/third_party/nvfuser/csrc/ir_iostream.cpp @@ -112,7 +112,7 @@ void IrTransformPrinter::printTransforms(TensorView* tv) { os() << ")\n"; } - os() << " contiguity: " << tv->domain()->contiguity() << "\n"; + os() << " contiguity: " << tv->domain()->getContiguityString() << "\n"; auto from = tv->getMaybeRFactorDomain(); auto all_exp = DependencyCheck::getAllExprsBetween( diff --git a/third_party/nvfuser/csrc/ir_nodes.cpp b/third_party/nvfuser/csrc/ir_nodes.cpp index a47450b821bb..12feb603e51a 100644 --- a/third_party/nvfuser/csrc/ir_nodes.cpp +++ b/third_party/nvfuser/csrc/ir_nodes.cpp @@ -2193,19 +2193,25 @@ Val* IterDomain::stop() const { TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector root_domain, - std::vector contiguity) + std::vector> contiguity) : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), contiguity_( - contiguity.empty() - ? std::vector(noBroadcasts(root_domain_).size(), false) - : std::move(contiguity)) { + contiguity.empty() ? getContiguityFilledWith(root_domain_, false) + : std::move(contiguity)) { TORCH_CHECK( - contiguity_.size() == noBroadcasts(getMaybeRFactorDomain()).size(), + contiguity_.size() == getMaybeRFactorDomain().size(), "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", - noBroadcasts(getMaybeRFactorDomain()).size()); + getMaybeRFactorDomain().size()); + for (auto i : c10::irange(contiguity_.size())) { + TORCH_CHECK( + getMaybeRFactorDomain().at(i)->isBroadcast() != + contiguity_.at(i).has_value(), + "The contiguity of a broadcast dimension must be None. " + "The contiguity of a non-broadcast dimension must be true/false"); + } // Just due to clang-tidy, correct value set in resetDomains has_reduction_ = false; @@ -2217,20 +2223,26 @@ TensorDomain::TensorDomain( IrBuilderPasskey passkey, std::vector root_domain, std::vector domain, - std::vector contiguity) + std::vector> contiguity) : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), domain_(std::move(domain)), contiguity_( - contiguity.empty() - ? std::vector(noBroadcasts(root_domain_).size(), false) - : std::move(contiguity)) { + contiguity.empty() ? getContiguityFilledWith(root_domain_, false) + : std::move(contiguity)) { TORCH_CHECK( - contiguity_.size() == noBroadcasts(getMaybeRFactorDomain()).size(), + contiguity_.size() == getMaybeRFactorDomain().size(), "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", root_domain_.size()); + for (auto i : c10::irange(contiguity_.size())) { + TORCH_CHECK( + getMaybeRFactorDomain().at(i)->isBroadcast() != + contiguity_.at(i).has_value(), + "The contiguity of a broadcast dimension must be None. 
" + "The contiguity of a non-broadcast dimension must be true/false"); + } std::vector domain_vals(domain_.begin(), domain_.end()); auto inps = IterVisitor::getInputsTo(domain_vals); @@ -2257,21 +2269,27 @@ TensorDomain::TensorDomain( std::vector root_domain, std::vector rfactor_domain, std::vector domain, - std::vector contiguity) + std::vector> contiguity) : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), domain_(std::move(domain)), rfactor_domain_(std::move(rfactor_domain)), contiguity_( - contiguity.empty() - ? std::vector(noBroadcasts(rfactor_domain_).size(), false) - : std::move(contiguity)) { + contiguity.empty() ? getContiguityFilledWith(rfactor_domain_, false) + : std::move(contiguity)) { TORCH_CHECK( - contiguity_.size() == noBroadcasts(getMaybeRFactorDomain()).size(), + contiguity_.size() == getMaybeRFactorDomain().size(), "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", getMaybeRFactorDomain().size()); + for (auto i : c10::irange(contiguity_.size())) { + TORCH_CHECK( + getMaybeRFactorDomain().at(i)->isBroadcast() != + contiguity_.at(i).has_value(), + "The contiguity of a broadcast dimension must be None. " + "The contiguity of a non-broadcast dimension must be true/false"); + } auto inps = IterVisitor::getInputsTo( std::vector(domain_.begin(), domain_.end())); @@ -2405,11 +2423,18 @@ std::string TensorDomain::toInlineString(int indent_size) const { return toString(indent_size); } -void TensorDomain::setContiguity(const std::vector& contig) { +void TensorDomain::setContiguity( + const std::vector>& contig) { TORCH_INTERNAL_ASSERT( - noBroadcasts(getMaybeRFactorDomain()).size() == contig.size(), - "Invalid contiguity vector: ", - contig); + getMaybeRFactorDomain().size() == contig.size(), + "Invalid size of contiguity vector"); + for (auto i : c10::irange(contig.size())) { + TORCH_CHECK( + getMaybeRFactorDomain().at(i)->isBroadcast() != + contig.at(i).has_value(), + "The contiguity of a broadcast dimension must be None. 
" + "The contiguity of a non-broadcast dimension must be true/false"); + } contiguity_ = contig; } @@ -2668,14 +2693,17 @@ std::vector TensorDomain::noBroadcasts( return noBroadcastDomain; } -std::vector TensorDomain::getContiguousContiguity( - const std::vector& rfactor_domain) { - std::vector contiguity; +std::vector> TensorDomain::getContiguityFilledWith( + const std::vector& rfactor_domain, + bool fill_value) { + std::vector> contiguity; + contiguity.reserve(rfactor_domain.size()); for (auto id : rfactor_domain) { if (id->isBroadcast()) { - continue; + contiguity.push_back(c10::nullopt); + } else { + contiguity.push_back(fill_value); } - contiguity.push_back(true); } return contiguity; } @@ -2768,7 +2796,7 @@ TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) { new_root_domain, rfactor_domain, rfactor_domain, - TensorDomain::getContiguousContiguity(rfactor_domain)); + TensorDomain::getContiguityFilledWith(rfactor_domain, true)); } // TODO: Rfactor a Welford diff --git a/third_party/nvfuser/csrc/ir_utils.cpp b/third_party/nvfuser/csrc/ir_utils.cpp index fcaf317d0fec..d2bd18a63d3b 100644 --- a/third_party/nvfuser/csrc/ir_utils.cpp +++ b/third_party/nvfuser/csrc/ir_utils.cpp @@ -799,17 +799,5 @@ std::string varName(const Val* val) { return name.str(); } -std::vector fullToNoBroadcastMap(const std::vector ids) { - std::vector full2nob_map( - ids.size(), std::numeric_limits::max()); - size_t no_broadcast_i = 0; - for (const auto i : c10::irange(ids.size())) { - if (!ids.at(i)->isBroadcast()) { - full2nob_map.at(i) = no_broadcast_i++; - } - } - return full2nob_map; -} - } // namespace ir_utils } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/ir_utils.h b/third_party/nvfuser/csrc/ir_utils.h index 99fded08417d..929619c66c4e 100644 --- a/third_party/nvfuser/csrc/ir_utils.h +++ b/third_party/nvfuser/csrc/ir_utils.h @@ -364,11 +364,5 @@ TORCH_CUDA_CU_API bool isTorchGatherLookupTv(const Val* tv); TORCH_CUDA_CU_API std::string varName(const Val* val); -// Given ids that may or may not contain broadcast IDs, find the mapping of -// indices between ids and TensorDomain::noBroadcasts(ids). 
For example, if ids -// is [I, I, b, b, I], then the result should be [0, 1, undefined, undefined, 2] -// -std::vector fullToNoBroadcastMap(const std::vector ids); - } // namespace ir_utils } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/lower_misaligned_vectorization.cpp b/third_party/nvfuser/csrc/lower_misaligned_vectorization.cpp index 09256c582981..b83208866567 100644 --- a/third_party/nvfuser/csrc/lower_misaligned_vectorization.cpp +++ b/third_party/nvfuser/csrc/lower_misaligned_vectorization.cpp @@ -488,11 +488,6 @@ class MisalignedVectorizationModifier : public kir::ExprMutator { const auto& producer_root_domain = producer_tv->getMaybeRFactorDomain(); const auto& consumer_root_domain = consumer_tv->getMaybeRFactorDomain(); - auto consumer_full2nob_map = - ir_utils::fullToNoBroadcastMap(consumer_root_domain); - auto producer_full2nob_map = - ir_utils::fullToNoBroadcastMap(producer_root_domain); - // Calculate extent of merged root domains Val* extent = nullptr; auto consumer_root_idx = int(consumer_root_domain.size()) - 1; @@ -537,8 +532,7 @@ class MisalignedVectorizationModifier : public kir::ExprMutator { // If it's not contiguous, extending the vectorization domain // further is not possible - if (!(producer_contig.at(producer_full2nob_map.at(i)) && - consumer_contig.at(producer_full2nob_map.at(consumer_root_idx)))) { + if (!(*producer_contig.at(i) && *consumer_contig.at(consumer_root_idx))) { break; } diff --git a/third_party/nvfuser/csrc/lower_utils.cpp b/third_party/nvfuser/csrc/lower_utils.cpp index 34b8c895dfc1..e8600af2d95c 100644 --- a/third_party/nvfuser/csrc/lower_utils.cpp +++ b/third_party/nvfuser/csrc/lower_utils.cpp @@ -55,18 +55,18 @@ ir_utils::TVDomainGuard overrideContiguityGuard( // Use domain guard to ignore the contiguity of // consumer tv. 
TensorDomain* domain_with_specified_contiguity = nullptr; - std::vector contiguity_vector( - TensorDomain::noBroadcasts(tv->getMaybeRFactorDomain()).size(), - contiguity); if (tv->hasRFactor()) { domain_with_specified_contiguity = IrBuilder::create( tv->getRootDomain(), tv->getRFactorDomain(), tv->domain()->domain(), - contiguity_vector); + TensorDomain::getContiguityFilledWith( + tv->getRFactorDomain(), contiguity)); } else { domain_with_specified_contiguity = IrBuilder::create( - tv->getRootDomain(), tv->domain()->domain(), contiguity_vector); + tv->getRootDomain(), + tv->domain()->domain(), + TensorDomain::getContiguityFilledWith(tv->getRootDomain(), contiguity)); } return ir_utils::TVDomainGuard(tv, domain_with_specified_contiguity); diff --git a/third_party/nvfuser/csrc/lower_validation.cpp b/third_party/nvfuser/csrc/lower_validation.cpp index 1bd8f0ddc878..352ab78dab14 100644 --- a/third_party/nvfuser/csrc/lower_validation.cpp +++ b/third_party/nvfuser/csrc/lower_validation.cpp @@ -189,7 +189,6 @@ void checkContiguity( TensorView* tv) { TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global); - int no_broadcast_i = 0; for (const auto idx : c10::irange(tv->getRootDomain().size())) { auto root = tv->getRootDomain()[idx]; if (domains.find(root) != domains.end()) { @@ -199,14 +198,11 @@ void checkContiguity( "Issue found in, ", tv); TORCH_INTERNAL_ASSERT( - tv->domain()->contiguity().at(no_broadcast_i), + *tv->domain()->contiguity().at(idx), "Cannot merge non-contiguous root domains with misaligned vectorization.", "Issue found in, ", tv); } - if (!root->isBroadcast()) { - no_broadcast_i++; - } } } @@ -226,20 +222,17 @@ void checkContiguity( PairwiseRootDomainMap(producer, consumer) .mapConsumerToProducer(consumer->domain(), producer->domain()); - std::unordered_map producer_domain_contiguity; - int no_broadcast_i = 0; + std::unordered_map> + producer_domain_contiguity; for (const auto idx : c10::irange(producer->getMaybeRFactorDomain().size())) { - auto root = producer->getMaybeRFactorDomain()[idx]; - auto contiguity = producer->domain()->contiguity().at(no_broadcast_i); + auto root = producer->getMaybeRFactorDomain().at(idx); + auto contiguity = producer->domain()->contiguity().at(idx); producer_domain_contiguity.insert({root, contiguity}); - if (!root->isBroadcast()) { - no_broadcast_i++; - } } for (auto consumer_root : consumer->getMaybeRFactorDomain()) { if (domains.find(consumer_root) != domains.end()) { - auto producer_root = root_c2p[consumer_root]; + auto producer_root = root_c2p.at(consumer_root); TORCH_INTERNAL_ASSERT( producer_domain_contiguity.find(producer_root) != producer_domain_contiguity.end()); @@ -253,7 +246,7 @@ void checkContiguity( TORCH_INTERNAL_ASSERT(root_c2p.find(consumer_root) != root_c2p.end()); TORCH_INTERNAL_ASSERT( - producer_domain_contiguity[producer_root], + *producer_domain_contiguity.at(producer_root), "Cannot merge non-contiguous root domains with misaligned vectorization.", "Issue found in, ", consumer); @@ -304,21 +297,19 @@ class VectorizeValidator : public OptInDispatch { // For the producer tensor, it's indexed first by transformed like // the consumer. So, to find its contig merged domain, use the // consumer TensorDomain with the producer contiguity info. 
- static std::vector mapProducerContiguity( + static std::vector> mapProducerContiguity( TensorView* producer_tv, TensorView* consumer_tv) { const auto c2p = PairwiseRootDomainMap(producer_tv, consumer_tv) .mapConsumerToProducer( consumer_tv->domain(), producer_tv->domain()); - std::vector producer_contiguity; - - auto producer_full2nob = - ir_utils::fullToNoBroadcastMap(producer_tv->getMaybeRFactorDomain()); + std::vector> producer_contiguity; for (auto consumer_root_id : consumer_tv->getRootDomain()) { auto producer_root_id = c2p.at(consumer_root_id); if (producer_root_id->isBroadcast()) { + producer_contiguity.push_back(c10::nullopt); continue; } auto producer_root_it = std::find( @@ -329,8 +320,8 @@ class VectorizeValidator : public OptInDispatch { producer_root_it != producer_tv->getMaybeRFactorDomain().end()); auto producer_root_id_offset = std::distance( producer_tv->getMaybeRFactorDomain().begin(), producer_root_it); - producer_contiguity.push_back(producer_tv->domain()->contiguity().at( - producer_full2nob.at(producer_root_id_offset))); + producer_contiguity.push_back( + producer_tv->domain()->contiguity().at(producer_root_id_offset)); } return producer_contiguity; @@ -425,12 +416,14 @@ class VectorizeValidator : public OptInDispatch { // Contiguity is based on rfactor domain. IterDomain* last_root_dim = nullptr; + size_t last_root_dim_pos; for (size_t i = tv->getMaybeRFactorDomain().size(); i > 0; i--) { auto r_id = tv->getMaybeRFactorDomain()[i - 1]; if (r_id->isReduction() || r_id->isBroadcast()) { continue; } last_root_dim = r_id; + last_root_dim_pos = i - 1; break; } @@ -442,7 +435,7 @@ class VectorizeValidator : public OptInDispatch { TORCH_CHECK( last_root_dim == validator.vectorized_id_ && - tv->domain()->contiguity().back(), + *tv->domain()->contiguity().at(last_root_dim_pos), "Vectorized dim has to be from a contiguous inner most position: ", tv, "\n"); diff --git a/third_party/nvfuser/csrc/multidevice/multidevice_runtime.cpp b/third_party/nvfuser/csrc/multidevice/multidevice_runtime.cpp index b5a536efccf1..7f909a362ddd 100644 --- a/third_party/nvfuser/csrc/multidevice/multidevice_runtime.cpp +++ b/third_party/nvfuser/csrc/multidevice/multidevice_runtime.cpp @@ -78,7 +78,8 @@ MultiDeviceRuntime::CompiledKernelPtr MultiDeviceRuntime::compileCluster( args.setDeviceIndex(device_index); // Lower the fusion and compile the generated kernel. 
- executor_ptr->compileFusion(fusion_from_cluster.get(), args, launch_params); + executor_ptr->compileFusion( + fusion_from_cluster.get(), args, launch_params, {}); return executor_ptr; } diff --git a/third_party/nvfuser/csrc/ops/alias.cpp b/third_party/nvfuser/csrc/ops/alias.cpp index 92c2c307b6e9..ef4fa61a947e 100644 --- a/third_party/nvfuser/csrc/ops/alias.cpp +++ b/third_party/nvfuser/csrc/ops/alias.cpp @@ -167,7 +167,7 @@ TensorView* squeeze(TensorView* x, const std::vector& to_squeeze) { auto out = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), *x->getDataType()); IrBuilder::create(x->container(), out, x, to_squeeze); @@ -316,7 +316,7 @@ TensorView* permute(TensorView* x, const std::vector& new2old) { TensorView* out_tensor = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), x->getDataType().value()); IrBuilder::create(out_tensor, x, normalized_new2old); return out_tensor; diff --git a/third_party/nvfuser/csrc/ops/arith.cpp b/third_party/nvfuser/csrc/ops/arith.cpp index accb7f25562d..08a8fdba7ef9 100644 --- a/third_party/nvfuser/csrc/ops/arith.cpp +++ b/third_party/nvfuser/csrc/ops/arith.cpp @@ -120,7 +120,7 @@ TensorView* select(TensorView* tv, int dim, Int* index) { } auto td = IrBuilder::create( - new_root, TensorDomain::getContiguousContiguity(new_root)); + new_root, TensorDomain::getContiguityFilledWith(new_root, true)); auto out = IrBuilder::create(td, *tv->getDataType()); IrBuilder::create(out, tv, dom.at(dim), index); return out; @@ -162,7 +162,7 @@ TensorView* index_select(TensorView* lookup_tv, int dim, TensorView* index_tv) { } auto td = IrBuilder::create( - new_root, TensorDomain::getContiguousContiguity(new_root)); + new_root, TensorDomain::getContiguityFilledWith(new_root, true)); auto out = IrBuilder::create(td, dtype); // broadcast index to lookup's rank. 
@@ -206,7 +206,7 @@ TensorView* torch_gather(TensorView* inp, int dim, TensorView* index) { TensorView* out_tensor = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), inp->getDataType().value()); IrBuilder::create( @@ -254,7 +254,7 @@ TensorView* scatterOp( TensorView* out_tensor = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), self->getDataType().value()); IrBuilder::create( @@ -273,12 +273,10 @@ TensorView* scatter( // TENSOR FACTORIES TensorView* rand(const std::vector& shape, DataType dtype) { auto n = shape.size(); - auto n_nob = std::count_if( - shape.begin(), shape.end(), [](auto x) { return !x->isOneInt(); }); auto out = TensorViewBuilder() .ndims(n) .dtype(dtype) - .contiguity(std::vector(n_nob, true)) + .contiguity(true) .shape(shape) .build(); IrBuilder::create(RNGOpType::Uniform, out, dtype); @@ -292,12 +290,10 @@ TensorView* uniform( Val* high, DataType dtype) { auto n = shape.size(); - auto n_nob = std::count_if( - shape.begin(), shape.end(), [](auto x) { return !x->isOneInt(); }); auto out = TensorViewBuilder() .ndims(n) .dtype(dtype) - .contiguity(std::vector(n_nob, true)) + .contiguity(true) .shape(shape) .build(); IrBuilder::create( @@ -311,12 +307,10 @@ TensorView* normal( Val* std, DataType dtype) { auto n = shape.size(); - auto n_nob = std::count_if( - shape.begin(), shape.end(), [](auto x) { return !x->isOneInt(); }); auto out = TensorViewBuilder() .ndims(n) .dtype(dtype) - .contiguity(std::vector(n_nob, true)) + .contiguity(true) .shape(shape) .build(); IrBuilder::create( @@ -326,12 +320,10 @@ TensorView* normal( TensorView* randn(const std::vector& shape, DataType dtype) { auto n = shape.size(); - auto n_nob = std::count_if( - shape.begin(), shape.end(), [](auto x) { return !x->isOneInt(); }); auto out = TensorViewBuilder() .ndims(n) .dtype(dtype) - .contiguity(std::vector(n_nob, true)) + .contiguity(true) .shape(shape) .build(); IrBuilder::create(RNGOpType::NormalStandard, out, dtype); @@ -382,12 +374,10 @@ TensorView* full( fill_value = castOp(dtype, fill_value); } auto n = shape.size(); - auto n_nob = std::count_if( - shape.begin(), shape.end(), [](auto x) { return !x->isOneInt(); }); auto out = TensorViewBuilder() .ndims(n) .dtype(dtype) - .contiguity(std::vector(n_nob, true)) + .contiguity(true) .shape(shape) .build(); IrBuilder::create(out, fill_value); @@ -469,12 +459,10 @@ TensorView* iota(Val* length, Val* start, Val* step, DataType dtype) { if (step->getDataType() != dtype) { step = castOp(dtype, step); } - auto contiguity = - length->isOneInt() ? 
std::vector{} : std::vector{true}; auto out = TensorViewBuilder() .ndims(1) .dtype(dtype) - .contiguity(contiguity) + .contiguity(true) .shape({length}) .build(); IrBuilder::create(out, length, start, step); @@ -535,16 +523,10 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) { TensorView* eye(Val* rows, Val* cols, DataType dtype) { TORCH_CHECK(rows->getDataType() == DataType::Int, "rows must have type Int"); TORCH_CHECK(cols->getDataType() == DataType::Int, "cols must have type Int"); - std::vector contiguity; - for (auto len : {rows, cols}) { - if (!len->isOneInt()) { - contiguity.push_back(true); - } - } auto out = TensorViewBuilder() .ndims(2) .dtype(dtype) - .contiguity(contiguity) + .contiguity(true) .shape(std::vector{rows, cols}) .build(); IrBuilder::create(out, dtype); @@ -1106,7 +1088,7 @@ static TensorView* newForReduction( } TensorDomain* td = IrBuilder::create( - new_domain, TensorDomain::getContiguousContiguity(new_domain)); + new_domain, TensorDomain::getContiguityFilledWith(new_domain, true)); data_type = data_type == DataType::Null ? tv->getDataType().value() : data_type; @@ -1228,7 +1210,7 @@ TensorView* maybeFullInsteadOfReduction( } TensorDomain* td = IrBuilder::create( - new_root, TensorDomain::getContiguousContiguity(new_root)); + new_root, TensorDomain::getContiguityFilledWith(new_root, true)); dtype = (dtype == DataType::Null ? tv->getDataType().value() : dtype); auto output = IrBuilder::create(td, dtype); @@ -1447,7 +1429,7 @@ TensorView* broadcast( TensorView* out_tensor = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), inp->getDataType().value()); IrBuilder::create(out_tensor, inp, is_broadcast_dim); return out_tensor; @@ -1517,7 +1499,7 @@ TensorView* expand(TensorView* inp, const std::vector& expanded_sizes) { TensorView* out_tensor = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), inp->getDataType().value()); if (!expanded) { IrBuilder::create(UnaryOpType::Set, out_tensor, inp); @@ -1577,7 +1559,7 @@ TensorView* expand_as(TensorView* inp, TensorView* other) { TensorView* out_tensor = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), inp->getDataType().value()); if (!expanded) { IrBuilder::create(UnaryOpType::Set, out_tensor, inp); @@ -2238,7 +2220,7 @@ TensorView* shift( out = IrBuilder::create( IrBuilder::create( - out_dom, TensorDomain::getContiguousContiguity(out_dom)), + out_dom, TensorDomain::getContiguityFilledWith(out_dom, true)), inp->getDataType().value()); IrBuilder::create(out, inp, offsets, pad_width); @@ -2262,7 +2244,8 @@ TensorDomain* generateTensorDomainWithStrides( std::all_of( strides.begin(), strides.end(), [](int s) { return s == 1; }))) { return IrBuilder::create( - root_domains, TensorDomain::getContiguousContiguity(root_domains)); + root_domains, + TensorDomain::getContiguityFilledWith(root_domains, true)); } for (const auto i : c10::irange(root_domains.size())) { @@ -2283,7 +2266,7 @@ TensorDomain* generateTensorDomainWithStrides( root_domains, strided_domains, strided_domains, - TensorDomain::getContiguousContiguity(strided_domains)); + TensorDomain::getContiguityFilledWith(strided_domains, true)); return strided_td; } @@ -2421,7 +2404,7 @@ 
TensorView* viewAsScalar(TensorView* inp) { auto out = IrBuilder::create( inp->container(), IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), out_type); IrBuilder::create(inp->container(), out, inp, id); @@ -2490,7 +2473,7 @@ static TensorView* newForMma( } TensorDomain* td = IrBuilder::create( - new_domain, TensorDomain::getContiguousContiguity(new_domain)); + new_domain, TensorDomain::getContiguityFilledWith(new_domain, true)); return IrBuilder::create(td, data_type); } diff --git a/third_party/nvfuser/csrc/ops/utils.cpp b/third_party/nvfuser/csrc/ops/utils.cpp index e2edc559fdf8..ba9e563aec44 100644 --- a/third_party/nvfuser/csrc/ops/utils.cpp +++ b/third_party/nvfuser/csrc/ops/utils.cpp @@ -274,7 +274,7 @@ TensorView* newOutputTV(const std::vector& vals, DataType dtype) { return IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguousContiguity(out_domain)), + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), dtype); } diff --git a/third_party/nvfuser/csrc/python_frontend/fusion_record.h b/third_party/nvfuser/csrc/python_frontend/fusion_record.h index c8ead8da945e..ff4b041c9150 100644 --- a/third_party/nvfuser/csrc/python_frontend/fusion_record.h +++ b/third_party/nvfuser/csrc/python_frontend/fusion_record.h @@ -1105,42 +1105,22 @@ struct TensorRecord : RecordFunctor { } virtual void operator()(FusionState& fd) final { - // auto tv = TensorViewBuilder() - // .ndims(symbolic_sizes_.size()) - // .contiguity(contiguous_info_) - // .shape(symbolic_sizes_) - // .dtype(dtype_) - // .build(); - std::vector sizes; - std::vector contig_info; int rank = symbolic_sizes_.size(); + std::vector is_expand(rank); for (const auto index : c10::irange(rank)) { - if (!contiguous_info_[index].has_value()) { - auto builder = IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), - FusionGuard::getCurFusion()->oneVal()) - .iter_type(IterType::Broadcast); - if (symbolic_sizes_[index] == 1) { - sizes.push_back(builder.build()); - } else if (symbolic_sizes_[index] == -1) { - sizes.push_back( - builder.expanded_extent(IrBuilder::create()).build()); - } else { - TORCH_INTERNAL_ASSERT( - false, "static shape in Tensor is not implemented yet"); - } - } else { - sizes.push_back(IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), - IrBuilder::create()) - .build()); - contig_info.push_back(contiguous_info_[index].value()); - } - } - - auto tv = IrBuilder::create( - IrBuilder::create(sizes, contig_info), dtype_); + bool is_broadcast = !contiguous_info_[index].has_value(); + bool has_symbolic_size = (symbolic_sizes_[index] == -1); + is_expand[index] = is_broadcast && has_symbolic_size; + } + + auto tv = TensorViewBuilder() + .ndims(symbolic_sizes_.size()) + .contiguity(contiguous_info_) + .shape(symbolic_sizes_) + .dtype(dtype_) + .expanded(std::move(is_expand)) + .build(); if (symbolic_sizes_.empty() && is_cpu_) { tv->setCpuScalar(true); diff --git a/third_party/nvfuser/csrc/scheduler/registry.cpp b/third_party/nvfuser/csrc/scheduler/registry.cpp index 7fbf4d4d64a5..1cb05e93767f 100644 --- a/third_party/nvfuser/csrc/scheduler/registry.cpp +++ b/third_party/nvfuser/csrc/scheduler/registry.cpp @@ -859,13 +859,11 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) { auto contiguity = tv->domain()->contiguity(); // Appears after reductions the reduction domain often has a contiguity entry. 
// This only matters if the result of the reduction is an output - auto tv_root_nob = TensorDomain::noBroadcasts(tv_root); - auto tv_root_norb = TensorDomain::noBroadcasts(tv_root_no_reductions); - if (contiguity.size() == tv_root_nob.size() && - contiguity.size() != tv_root_norb.size()) { - std::vector new_contiguity; - for (auto i : c10::irange(tv_root_nob.size())) { - if (!tv_root_nob[i]->isReduction()) { + if (contiguity.size() == tv_root.size() && + contiguity.size() != tv_root_no_reductions.size()) { + std::vector> new_contiguity; + for (auto i : c10::irange(tv_root.size())) { + if (!tv_root[i]->isReduction()) { new_contiguity.push_back(contiguity[i]); } } @@ -873,15 +871,15 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) { } tv_root = tv_root_no_reductions; - auto tv_root_nob_size = tv_root_nob.size(); + auto tv_root_size = tv_root.size(); // Filter out 0-dim tensors - if (tv_root_nob_size < 1) { + if (tv_root_size < 1) { return 1; } // Filter out mismatched contiguity info - if (tv_root_nob_size != contiguity.size()) { + if (tv_root_size != contiguity.size()) { return 1; } @@ -897,16 +895,16 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) { } size_t numel = 1; - for (auto i : c10::irange(tv_root_nob_size)) { - auto root_i = tv_root_nob_size - i - 1; - auto root_id = tv_root_nob[root_i]; + for (auto i : c10::irange(tv_root_size)) { + auto root_i = tv_root_size - i - 1; + auto root_id = tv_root[root_i]; if (root_id->extent()->isOneInt() || root_id->isBroadcast()) { continue; } // Not contiguous - if (!contiguity[root_i]) { + if (!*contiguity[root_i]) { break; } @@ -953,17 +951,15 @@ size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) { : tv->getMaybeRFactorDomain(); auto tv_root_no_reductions = TensorDomain::noReductions(tv_root); - auto tv_root_nob = TensorDomain::noBroadcasts(tv_root); - auto tv_root_norb = TensorDomain::noBroadcasts(tv_root_no_reductions); auto contiguity = tv->domain()->contiguity(); // Appears after reductions the reduction domain often has a contiguity entry. 
// This only matters if the result of the reduction is an output - if (contiguity.size() == tv_root_nob.size() && - contiguity.size() != tv_root_norb.size()) { - std::vector new_contiguity; - for (auto i : c10::irange(tv_root_nob.size())) { - if (!tv_root_nob[i]->isReduction()) { + if (contiguity.size() == tv_root.size() && + contiguity.size() != tv_root_no_reductions.size()) { + std::vector> new_contiguity; + for (auto i : c10::irange(tv_root.size())) { + if (!tv_root[i]->isReduction()) { new_contiguity.push_back(contiguity[i]); } } @@ -971,23 +967,23 @@ size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) { } tv_root = tv_root_no_reductions; - auto tv_root_norb_size = tv_root_norb.size(); + auto tv_root_no_reductions_size = tv_root_no_reductions.size(); // Filter out 0-dim tensors - if (tv_root_norb_size < 1) { + if (tv_root_no_reductions_size < 1) { return 1; } // Filter out mismatched contiguity info - if (tv_root_norb_size != contiguity.size()) { + if (tv_root_no_reductions_size != contiguity.size()) { return 1; } auto inner_most_dim = scheduler_utils::innerMostRootDim(tv); int id_pos = -1; - for (auto root_i : c10::irange(tv_root_norb_size)) { - if (tv_root_norb[root_i] == inner_most_dim) { + for (auto root_i : c10::irange(tv_root_no_reductions_size)) { + if (tv_root_no_reductions[root_i] == inner_most_dim) { id_pos = root_i; break; } @@ -1000,7 +996,7 @@ size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) { } // If the inner most dimension is not contiguous return 1 - if (!contiguity[id_pos]) { + if (!*contiguity[id_pos]) { return 1; } diff --git a/third_party/nvfuser/csrc/scheduler/utils.cpp b/third_party/nvfuser/csrc/scheduler/utils.cpp index 50c04df55b3c..f3bc4a865213 100644 --- a/third_party/nvfuser/csrc/scheduler/utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/utils.cpp @@ -1256,26 +1256,25 @@ bool hasInnerDim( return true; } - auto rfactor_dom_nob = - TensorDomain::noBroadcasts(tv->getMaybeRFactorDomain()); + auto rfactor_dom = tv->getMaybeRFactorDomain(); auto root_pos_it = std::find_if( - rfactor_dom_nob.begin(), - rfactor_dom_nob.end(), + rfactor_dom.begin(), + rfactor_dom.end(), [&inner_most_dim](IterDomain* id) { return inner_most_dim == id; }); - if (root_pos_it == rfactor_dom_nob.end()) { + if (root_pos_it == rfactor_dom.end()) { return false; } - auto inner_most_dim_pos = std::distance(rfactor_dom_nob.begin(), root_pos_it); + auto inner_most_dim_pos = std::distance(rfactor_dom.begin(), root_pos_it); const auto& contiguity = tv->domain()->contiguity(); - TORCH_INTERNAL_ASSERT(contiguity.size() == rfactor_dom_nob.size()); + TORCH_INTERNAL_ASSERT(contiguity.size() == rfactor_dom.size()); // Don't vectorize if inner most dimension is not contiguous - if (!contiguity[inner_most_dim_pos]) { + if (!*contiguity[inner_most_dim_pos]) { return false; } diff --git a/third_party/nvfuser/csrc/scheduler/vectorize_helper.cpp b/third_party/nvfuser/csrc/scheduler/vectorize_helper.cpp index 2bc1418cf0ad..480eab76c264 100644 --- a/third_party/nvfuser/csrc/scheduler/vectorize_helper.cpp +++ b/third_party/nvfuser/csrc/scheduler/vectorize_helper.cpp @@ -1027,32 +1027,30 @@ std::vector> getContigVectorSizesOf( : mapper.mappedRFactorIds(of_tv); auto of_tv_root_no_reductions = TensorDomain::noReductions(of_tv_root); - auto of_tv_root_nob = TensorDomain::noBroadcasts(of_tv_root); - auto of_tv_root_norb = TensorDomain::noBroadcasts(of_tv_root_no_reductions); - auto contiguity = of_tv->domain()->contiguity(); // Appears after reductions the reduction 
domain often has a contiguity entry. // This only matters if the result of the reduction is an output - if (contiguity.size() == of_tv_root_nob.size() && - contiguity.size() != of_tv_root_norb.size()) { - std::vector new_contiguity; - for (auto i : c10::irange(of_tv_root_nob.size())) { - if (!of_tv_root_nob[i]->isReduction()) { + if (contiguity.size() == of_tv_root.size() && + contiguity.size() != of_tv_root_no_reductions.size()) { + std::vector> new_contiguity; + for (auto i : c10::irange(of_tv_root.size())) { + if (!of_tv_root[i]->isReduction()) { new_contiguity.push_back(contiguity[i]); } } contiguity = new_contiguity; } - auto of_tv_root_norb_size = of_tv_root_norb.size(); + auto of_tv_root_no_reductions_size = of_tv_root_no_reductions.size(); // Filter out 0-dim tensors - if (of_tv_root_norb_size < 1) { + if (of_tv_root_no_reductions_size < 1) { return {}; } TORCH_INTERNAL_ASSERT( - of_tv_root_norb_size == contiguity.size(), "Contiguity mismatch found."); + of_tv_root_no_reductions_size == contiguity.size(), + "Contiguity mismatch found."); std::vector> vectorizable_dim_sizes; @@ -1063,12 +1061,12 @@ std::vector> getContigVectorSizesOf( // vectorize dimension. size_t projected_dims_i = projected_dims.size(); - for (auto i : c10::irange(of_tv_root_norb_size)) { + for (auto i : c10::irange(of_tv_root_no_reductions_size)) { if (projected_dims_i == 0) { break; } - auto root_i = of_tv_root_norb_size - i - 1; - auto root_id = of_tv_root_norb.at(root_i); + auto root_i = of_tv_root_no_reductions_size - i - 1; + auto root_id = of_tv_root_no_reductions.at(root_i); if (root_id->extent()->isOneInt() || root_id->isBroadcast()) { if (projected_dims[projected_dims_i - 1]->sameAs(root_id)) { @@ -1078,7 +1076,7 @@ std::vector> getContigVectorSizesOf( } // Not contiguous - if (!contiguity[root_i]) { + if (!*contiguity[root_i]) { break; } diff --git a/third_party/nvfuser/csrc/tensor_view.cpp b/third_party/nvfuser/csrc/tensor_view.cpp index 7d1084e93a0d..5a00cfed7d01 100644 --- a/third_party/nvfuser/csrc/tensor_view.cpp +++ b/third_party/nvfuser/csrc/tensor_view.cpp @@ -110,8 +110,8 @@ TensorView::TensorView( } // default to non_contiguous; - std::vector contig_info( - TensorDomain::noBroadcasts(sizes).size(), false); + std::vector> contig_info = + TensorDomain::getContiguityFilledWith(sizes, false); int64_t inner_most_non_broadcast = tensor_type->dim().value() - 1; while (inner_most_non_broadcast >= 0) { @@ -123,8 +123,6 @@ TensorView::TensorView( } // if all broadcast, then inner_most_non_broadcast == -1 - auto full2nob = ir_utils::fullToNoBroadcastMap(sizes); - // we iterate through stride_index_, which goes from fastest changing // dimension to slowest, instead of iterating through sizes. 
This allows // easier contiguity check; @@ -148,8 +146,7 @@ TensorView::TensorView( if (!found_innermost_non_broadcast) { // mark fastest changing dimension collapsible only when it's // "innermost" - contig_info.at(full2nob.at(index)) = - ((int64_t)index == inner_most_non_broadcast); + contig_info.at(index) = ((int64_t)index == inner_most_non_broadcast); } else { // check the neighboring faster dimension, collapse if it is considered // as inner dimension per stride_index @@ -159,8 +156,7 @@ TensorView::TensorView( if (inner_index_opt.has_value() && inner_index_opt.value() == (index + 1)) { // collapse if inner dimension has non-broadcasted strides - contig_info.at(full2nob.at(index)) = - !sizes.at(index + 1)->isBroadcast(); + contig_info.at(index) = !sizes.at(index + 1)->isBroadcast(); } } } @@ -274,8 +270,7 @@ void TensorView::convertRfactorToRootDomain() { } TORCH_INTERNAL_ASSERT( - TensorDomain::noBroadcasts(new_root_domain).size() == - domain()->contiguity().size()); + new_root_domain.size() == domain()->contiguity().size()); setDomain(IrBuilder::create( container(), new_root_domain, domain()->contiguity())); }; @@ -1076,8 +1071,7 @@ TensorView* TensorView::multiOutputRfactorHelper( new_id.push_back(replay.getReplay().at(id)); } - std::vector new_contig( - tv->domain()->contiguity().begin(), tv->domain()->contiguity().end()); + std::vector> new_contig(tv->domain()->contiguity()); // replace tensor domain of target tv tv->setDomain(IrBuilder::create( tv->getRootDomain(), new_id, new_contig)); @@ -1259,7 +1253,7 @@ TensorView* TensorView::cacheBefore(c10::optional cache_op) { consumer->setDomain(IrBuilder::create( container(), new_root_domain, - TensorDomain::getContiguousContiguity(new_root_domain))); + TensorDomain::getContiguityFilledWith(new_root_domain, true))); // Insert producer - Cache_Before (CB) - before this TV. 
// Before: Prev TV -> [Definition Op] -> This TV @@ -1319,7 +1313,7 @@ TensorView* TensorView::cacheFork() { IrBuilder::create( container(), IterDomain::clone(root_domain), - TensorDomain::getContiguousContiguity(root_domain)), + TensorDomain::getContiguityFilledWith(root_domain, true)), getDataType().value()); // Create write operation from this TV to new output @@ -1391,7 +1385,7 @@ TensorView* TensorView::cacheAfter(c10::optional cache_op) { IrBuilder::create( container(), new_root_domain, - TensorDomain::getContiguousContiguity(new_root_domain)), + TensorDomain::getContiguityFilledWith(new_root_domain, true)), getDataType().value()); // Set domain of producer - No Change @@ -1437,18 +1431,12 @@ void TensorView::clearReductionIterDomains() { "should not call clearReductionIterDomains on already transformed TensorDomains"); std::vector new_root; - std::vector new_contig; - int64_t no_broadcast_i = 0; + std::vector> new_contig; for (const auto i : c10::irange(getRootDomain().size())) { auto root_i = getRootDomain().at(i); if (!root_i->isReduction()) { new_root.push_back(root_i); - if (!root_i->isBroadcast()) { - new_contig.push_back(domain()->contiguity().at(no_broadcast_i)); - } - } - if (!root_i->isBroadcast()) { - no_broadcast_i++; + new_contig.push_back(domain()->contiguity().at(i)); } } @@ -1513,12 +1501,23 @@ TensorViewBuilder& TensorViewBuilder::dtype(DataType dtype) { return *this; } -TensorViewBuilder& TensorViewBuilder::contiguity(std::vector contiguity) { - TORCH_CHECK(contiguity_.empty(), "Attempting to reset contiguity"); +TensorViewBuilder& TensorViewBuilder::contiguity( + std::vector> contiguity) { + TORCH_CHECK( + contiguity_.empty() && !uniform_contiguity_.has_value(), + "Attempting to reset contiguity"); contiguity_ = std::move(contiguity); return *this; } +TensorViewBuilder& TensorViewBuilder::contiguity(bool contiguity) { + TORCH_CHECK( + contiguity_.empty() && !uniform_contiguity_.has_value(), + "Attempting to reset contiguity"); + uniform_contiguity_ = contiguity; + return *this; +} + TensorViewBuilder& TensorViewBuilder::shape(const std::vector& shape) { TORCH_CHECK(shape_.empty(), "Attempting to reset shape"); if (!shape.empty()) { @@ -1568,7 +1567,6 @@ TensorViewBuilder& TensorViewBuilder::expanded(std::vector expanded) { TensorView* TensorViewBuilder::build() const { // Build the domain std::vector domain(ndims_, nullptr); - size_t non_broadcasting_dims = 0; for (const auto i : c10::irange(ndims_)) { bool is_expanded = false; Val* extent = nullptr; @@ -1597,8 +1595,6 @@ TensorView* TensorViewBuilder::build() const { IterDomainBuilder builder(FusionGuard::getCurFusion()->zeroVal(), extent); if (extent->isOneInt()) { builder.iter_type(IterType::Broadcast); - } else { - non_broadcasting_dims++; } if (expanded_extent != nullptr) { builder.expanded_extent(expanded_extent); @@ -1607,12 +1603,32 @@ TensorView* TensorViewBuilder::build() const { } TORCH_CHECK( - contiguity_.empty() || contiguity_.size() == non_broadcasting_dims, + contiguity_.empty() || contiguity_.size() == domain.size(), "The size of contiguity must equal to the number of non-broadcasting IterDomains"); - // Create the final TensorView - return IrBuilder::create( - IrBuilder::create(domain, contiguity_), dtype_); + for (auto i : c10::irange(contiguity_.size())) { + TORCH_CHECK( + domain.at(i)->isBroadcast() != contiguity_.at(i).has_value(), + "The contiguity of a broadcast dimension must be None. 
" + "The contiguity of a non-broadcast dimension must be true/false"); + } + + if (uniform_contiguity_.has_value()) { + TORCH_INTERNAL_ASSERT( + contiguity_.empty(), + "contiguity_ and uniform_contiguity_ can not be set at the same time"); + // Create the final TensorView + return IrBuilder::create( + IrBuilder::create( + domain, + TensorDomain::getContiguityFilledWith( + domain, *uniform_contiguity_)), + dtype_); + } else { + // Create the final TensorView + return IrBuilder::create( + IrBuilder::create(domain, contiguity_), dtype_); + } } } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/transform_rfactor.cpp b/third_party/nvfuser/csrc/transform_rfactor.cpp index a56145877107..5d7b4fdb2505 100644 --- a/third_party/nvfuser/csrc/transform_rfactor.cpp +++ b/third_party/nvfuser/csrc/transform_rfactor.cpp @@ -406,7 +406,7 @@ std::pair TransformRFactor::runReplay( new_producer_root, new_producer_rfactor_domain, new_producer_domain, - TensorDomain::getContiguousContiguity(new_producer_rfactor_domain)); + TensorDomain::getContiguityFilledWith(new_producer_rfactor_domain, true)); // Producer has been finished, now work on consumer. @@ -464,7 +464,7 @@ std::pair TransformRFactor::runReplay( original_td->container(), new_consumer_root_domain, new_consumer_domain, - TensorDomain::getContiguousContiguity(new_consumer_root_domain)); + TensorDomain::getContiguityFilledWith(new_consumer_root_domain, true)); return std::make_pair(producer_domain, consumer_domain); } diff --git a/third_party/nvfuser/csrc/transform_view.cpp b/third_party/nvfuser/csrc/transform_view.cpp index 0214560b326c..220b174e0da9 100644 --- a/third_party/nvfuser/csrc/transform_view.cpp +++ b/third_party/nvfuser/csrc/transform_view.cpp @@ -671,7 +671,7 @@ TensorDomain* createViewDomain( new_root_domain, new_rfactor_domain, new_rfactor_domain, - TensorDomain::getContiguousContiguity(new_rfactor_domain)); + TensorDomain::getContiguityFilledWith(new_rfactor_domain, true)); } } // namespace diff --git a/third_party/nvfuser/test/test_gpu1.cpp b/third_party/nvfuser/test/test_gpu1.cpp index d240f35be3a4..5a0993906008 100644 --- a/third_party/nvfuser/test/test_gpu1.cpp +++ b/third_party/nvfuser/test/test_gpu1.cpp @@ -881,9 +881,9 @@ TEST_F(NVFuserTest, FusionTensor_CUDA) { // size 1 dimension are makred as broadcast TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); } - TORCH_CHECK(fuser_tensor->domain()->contiguity()[0]); - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); - TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); + TORCH_CHECK(*fuser_tensor->domain()->contiguity()[0]); + TORCH_CHECK(!*fuser_tensor->domain()->contiguity()[1]); + TORCH_CHECK(*fuser_tensor->domain()->contiguity()[2]); } { @@ -898,10 +898,10 @@ TEST_F(NVFuserTest, FusionTensor_CUDA) { // size 1 dimension are makred as broadcast TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); } - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[0]); - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); - TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[3]); + TORCH_CHECK(!*fuser_tensor->domain()->contiguity()[0]); + TORCH_CHECK(!*fuser_tensor->domain()->contiguity()[1]); + TORCH_CHECK(*fuser_tensor->domain()->contiguity()[2]); + TORCH_CHECK(!*fuser_tensor->domain()->contiguity()[3]); } } diff --git a/third_party/nvfuser/test/test_gpu3.cpp b/third_party/nvfuser/test/test_gpu3.cpp index 518441f6aad7..678d320b834a 100644 --- a/third_party/nvfuser/test/test_gpu3.cpp +++ 
diff --git a/third_party/nvfuser/test/test_gpu3.cpp b/third_party/nvfuser/test/test_gpu3.cpp
index 518441f6aad7..678d320b834a 100644
--- a/third_party/nvfuser/test/test_gpu3.cpp
+++ b/third_party/nvfuser/test/test_gpu3.cpp
@@ -1291,7 +1291,7 @@ TEST_F(NVFuserTest, FusionIssue1430_CUDA) {
   auto tv0 = TensorViewBuilder()
                  .ndims(5)
                  .dtype(DataType::Half)
-                 .contiguity(std::vector<bool>(5, true))
+                 .contiguity(true)
                  .shape({V, W, X, Y, Z})
                  .build();

@@ -4089,7 +4089,7 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) {
                  .build();
   auto tv1 = TensorViewBuilder()
                  .ndims(4)
-                 .contiguity({true, true})
+                 .contiguity({true, c10::nullopt, c10::nullopt, true})
                  .shape({-1, 1, 1, -1})
                  .dtype(DataType::Half)
                  .build();
@@ -4702,7 +4702,7 @@ TEST_F(NVFuserTest, FusionExpandRepro1860_CUDA) {
   auto fusion_ptr = std::make_unique<Fusion>();
   Fusion& fusion = *fusion_ptr;
   FusionGuard fg(&fusion);
-  std::vector<bool> contiguity{};
+  std::vector<c10::optional<bool>> contiguity(3, c10::nullopt);
   std::vector<int64_t> shape{1, -1, -1};
   TensorView* tv0 = makeContigConcreteTensor(shape);
@@ -4861,7 +4861,7 @@ TEST_F(NVFuserTest, FusionExpandBadShapeTest_CUDA) {
   auto fusion_ptr = std::make_unique<Fusion>();
   Fusion& fusion = *fusion_ptr;
   FusionGuard fg(&fusion);
-  std::vector<bool> contiguity{false};
+  std::vector<c10::optional<bool>> contiguity{false, c10::nullopt};
   auto tv0 = makeSymbolicTensor(2);
   fusion.addInput(tv0);
@@ -5998,7 +5998,7 @@ TEST_F(NVFuserTest, FusionExpandedInput_CUDA) {
   TensorView* tv0 = TensorViewBuilder()
                         .ndims(3)
                         .shape({-1, -1, -1})
-                        .contiguity({false, true})
+                        .contiguity({false, c10::nullopt, true})
                         .expanded({false, true, false})
                         .build();
   fusion->addInput(tv0);
@@ -6862,7 +6862,8 @@ TEST_F(NVFuserTest, FusionSqueezeOnlyWelford_CUDA) {
     auto dim0 = IterDomainBuilder(w1.avg->axis(0)).build();
     auto dim1 = IterDomainBuilder(w1.avg->axis(1)).build();
     auto td = IrBuilder::create<TensorDomain>(
-        std::vector<IterDomain*>{dim0, dim1}, std::vector<bool>{true, true});
+        std::vector<IterDomain*>{dim0, dim1},
+        std::vector<c10::optional<bool>>{true, true});
     auto tv = IrBuilder::create<TensorView>(td, dtype);
     return tv;
   };
diff --git a/third_party/nvfuser/test/test_gpu_utils.cpp b/third_party/nvfuser/test/test_gpu_utils.cpp
index d67d3dab2fdf..4c8ca6f60d64 100644
--- a/third_party/nvfuser/test/test_gpu_utils.cpp
+++ b/third_party/nvfuser/test/test_gpu_utils.cpp
@@ -217,9 +217,9 @@ TEST_F(NVFuserTest, FusionTVDomainGuard_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);

-  std::vector<bool> all_true = {true, true};
-  std::vector<bool> all_false = {false, false};
-  std::vector<bool> false_true = {false, true};
+  std::vector<c10::optional<bool>> all_true = {true, true};
+  std::vector<c10::optional<bool>> all_false = {false, false};
+  std::vector<c10::optional<bool>> false_true = {false, true};
   auto tv = TensorViewBuilder().ndims(2).contiguity(false_true).build();
   TORCH_CHECK(tv->domain()->contiguity() == false_true);
   {
diff --git a/third_party/nvfuser/test/test_gpu_view.cpp b/third_party/nvfuser/test/test_gpu_view.cpp
index d1484f9cd6ec..61321acbbfca 100644
--- a/third_party/nvfuser/test/test_gpu_view.cpp
+++ b/third_party/nvfuser/test/test_gpu_view.cpp
@@ -2142,7 +2142,7 @@ TEST_F(NVFuserTest, FusionIssue2076_CUDA) {
   auto tv0 = TensorViewBuilder()
                  .shape({-1, 1, -1, -1})
                  .dtype(DataType::Bool)
-                 .contiguity({true, true, true})
+                 .contiguity({true, c10::nullopt, true, true})
                  .build();
   fusion.addInput(tv0);

@@ -2247,8 +2247,10 @@ TEST_F(NVFuserTest, FusionIssue2076_v2_CUDA) {
   // torch.randn(4, 128, 1, device='cuda').transpose(1,2)
   // sizes[4, 128, 1] strides[128, 1, 1]
   // sizes[4, 1, 128] strides[128, 128, 1]
-  auto tv0 =
-      TensorViewBuilder().shape({-1, 1, -1}).contiguity({true, true}).build();
+  auto tv0 = TensorViewBuilder()
+                 .shape({-1, 1, -1})
+                 .contiguity({true, c10::nullopt, true})
+                 .build();
   fusion.addInput(tv0);

   // torch.randn(48, 128, device='cuda')
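TensorDomain::getContiguityFilledWith, used throughout the hunks above in place of getContiguousContiguity, is relied on here only for the shape of its result: one entry per dimension, the fill value for regular dimensions and nullopt for broadcast dimensions. A standalone model of that invariant follows; std::optional stands in for c10::optional, the function operates on plain broadcast flags rather than IterDomains, and its name and signature are assumptions for illustration, not the real helper.

#include <optional>
#include <vector>

// Model of the "filled" contiguity vector: fill_value for every regular
// dimension, nullopt for every broadcast dimension.
// e.g. is_broadcast = {false, true, false}, fill_value = true
//      -> {true, nullopt, true}
std::vector<std::optional<bool>> contiguityFilledWith(
    const std::vector<bool>& is_broadcast,
    bool fill_value) {
  std::vector<std::optional<bool>> contiguity(is_broadcast.size());
  for (size_t i = 0; i < is_broadcast.size(); ++i) {
    if (!is_broadcast[i]) {
      contiguity[i] = fill_value; // regular dimension: explicit true/false
    }                             // broadcast dimension: stays nullopt
  }
  return contiguity;
}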
diff --git a/third_party/nvfuser/test/test_multicluster_fusion.cpp b/third_party/nvfuser/test/test_multicluster_fusion.cpp
index 5449b15bb335..3a889a674aaa 100644
--- a/third_party/nvfuser/test/test_multicluster_fusion.cpp
+++ b/third_party/nvfuser/test/test_multicluster_fusion.cpp
@@ -111,19 +111,19 @@ TEST_F(NVFuserTest, MultiClusterFusion) {
       "AggregateDag's Traversal inputs --> outputs {\n"
       " AggregateExpr representing Cluster 0.Inputs={T0_g[ iS0{i0}, iS1{i2}, iS2{i3} ], }. Outputs={T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ], }.\n"
       " AggregateVal representing Val T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ] on cluster 0\n"
-      " Send/Receive Val {T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ]} from cluster 0 to cluster 2\n"
-      " AggregateVal representing Val T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ] on cluster 2\n"
-      " AggregateExpr representing Cluster 2.Inputs={T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ], }. Outputs={T5_l[ rS12{i2}, iS13{i3} ], }.\n"
-      " AggregateVal representing Val T5_l[ rS12{i2}, iS13{i3} ] on cluster 2\n"
-      " Send/Receive Val {T5_l[ rS12{i2}, iS13{i3} ]} from cluster 2 to cluster 3\n"
-      " AggregateVal representing Val T5_l[ rS12{i2}, iS13{i3} ] on cluster 3\n"
       " Send/Receive Val {T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ]} from cluster 0 to cluster 1\n"
       " AggregateVal representing Val T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ] on cluster 1\n"
       " AggregateExpr representing Cluster 1.Inputs={T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ], }. Outputs={T3_l[ rS8{i2}, iS9{i3} ], }.\n"
       " AggregateVal representing Val T3_l[ rS8{i2}, iS9{i3} ] on cluster 1\n"
       " Send/Receive Val {T3_l[ rS8{i2}, iS9{i3} ]} from cluster 1 to cluster 3\n"
       " AggregateVal representing Val T3_l[ rS8{i2}, iS9{i3} ] on cluster 3\n"
-      " AggregateExpr representing Cluster 3.Inputs={T5_l[ rS12{i2}, iS13{i3} ], T3_l[ rS8{i2}, iS9{i3} ], }. Outputs={T6_g[ iS14{i3} ], }.\n"
+      " Send/Receive Val {T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ]} from cluster 0 to cluster 2\n"
+      " AggregateVal representing Val T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ] on cluster 2\n"
+      " AggregateExpr representing Cluster 2.Inputs={T1_l[ iS3{i0}, iS4{i2}, iS5{i3} ], }. Outputs={T5_l[ rS12{i2}, iS13{i3} ], }.\n"
+      " AggregateVal representing Val T5_l[ rS12{i2}, iS13{i3} ] on cluster 2\n"
+      " Send/Receive Val {T5_l[ rS12{i2}, iS13{i3} ]} from cluster 2 to cluster 3\n"
+      " AggregateVal representing Val T5_l[ rS12{i2}, iS13{i3} ] on cluster 3\n"
+      " AggregateExpr representing Cluster 3.Inputs={T3_l[ rS8{i2}, iS9{i3} ], T5_l[ rS12{i2}, iS13{i3} ], }. Outputs={T6_g[ iS14{i3} ], }.\n"
      "}\n"
      "AggregateDag's outputs:{\n"
      " AggregateVal representing Val T6_g[ iS14{i3} ] on cluster 3\n"
diff --git a/third_party/nvfuser/test/test_utils.h b/third_party/nvfuser/test/test_utils.h
index a757c7925eb4..2c1ce8f9d5a3 100644
--- a/third_party/nvfuser/test/test_utils.h
+++ b/third_party/nvfuser/test/test_utils.h
@@ -28,11 +28,7 @@ namespace nvfuser {
 inline TensorView* makeContigTensor(
     size_t ndims,
     DataType dtype = DataType::Float) {
-  return TensorViewBuilder()
-      .ndims(ndims)
-      .dtype(dtype)
-      .contiguity(std::vector<bool>(ndims, true))
-      .build();
+  return TensorViewBuilder().ndims(ndims).dtype(dtype).contiguity(true).build();
 }

 // Make a tensor that is known to be non-contiguous of dimensionality=ndims,
@@ -53,18 +49,7 @@ inline TensorView* makeConcreteTensor(
 inline TensorView* makeContigConcreteTensor(
     std::vector<int64_t> shape,
     DataType dtype = DataType::Float) {
-  std::vector<bool> contiguity;
-  for (auto s : shape) {
-    if (s == 1) {
-      continue;
-    }
-    contiguity.push_back(true);
-  }
-  return TensorViewBuilder()
-      .shape(shape)
-      .dtype(dtype)
-      .contiguity(contiguity)
-      .build();
+  return TensorViewBuilder().shape(shape).dtype(dtype).contiguity(true).build();
 }

 inline void checkIntValue(
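With the uniform contiguity(true) overload, these helpers no longer need to skip size-1 dimensions by hand: the builder emits one entry per dimension and marks broadcast dimensions as nullopt itself. Hypothetical usage, with illustrative shapes and assuming an active Fusion/FusionGuard; the expected contiguity values are inferred from the builder change above.

// Symbolic, fully contiguous 3-D tensor: contiguity {true, true, true}.
auto a = makeContigTensor(3);

// Concrete shape with a size-1 dimension: contiguity {true, nullopt, true}.
auto b = makeContigConcreteTensor({4, 1, 8});

Keeping an explicit placeholder for broadcast dimensions keeps contiguity entries aligned with root-domain indices, which is what the per-index lookups in the earlier hunks (for example the rewritten clearReductionIterDomains and the builder's per-dimension check) depend on.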