From ad1b06b04c9e367a0cefbb738a4505708669a1d1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 1 Sep 2022 17:41:29 -0700 Subject: [PATCH] Fix detection of unmappable root domains ComputeAtRootDomainMap flags domains that should not be mapped due to reductions. Previously, checking if a domain potentially causes an invalid mapping is only done with one domain in each group of domains that are found to be mappable so far. That's not actually sufficient as the unmappable domain set is created just once with no root mapping information. The fix is to check all consumer domains of a producer tensor. A small other fix is also done to address a different problem discovered after the first fix. --- .../csrc/jit/codegen/cuda/root_domain_map.cpp | 103 +++++++++++------- torch/csrc/jit/codegen/cuda/root_domain_map.h | 9 +- torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 32 ++++++ 3 files changed, 103 insertions(+), 41 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index 06a6d5ec1d17..fdf6f15b24e0 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -234,12 +234,12 @@ class FindInputDomains : BackwardVisitor { private: FindInputDomains(TensorView* tv, const IterDomain* id) : BackwardVisitor(false), tv_(tv) { - input_keys.insert(DomainKey(tv_->domain(), id)); + input_keys_.insert(DomainKey(tv_->domain(), id)); } DomainKeySet find() { traverseFrom(tv_->fusion(), {tv_}); - return input_keys; + return input_keys_; } void handle(Expr* expr) override { @@ -261,7 +261,7 @@ class FindInputDomains : BackwardVisitor { .mapConsumerToProducer(out_tv->domain(), in_tv->domain()); for (auto root_dom : out_tv->getRootDomain()) { DomainKey out_key({out_tv->domain(), root_dom}); - if (input_keys.find(out_key) == input_keys.end()) { + if (input_keys_.find(out_key) == input_keys_.end()) { continue; } auto input_id_it = c2p.find(root_dom); @@ -269,13 +269,13 @@ class FindInputDomains : BackwardVisitor { continue; } DomainKey input_key(in_tv->domain(), input_id_it->second); - input_keys.insert(input_key); + input_keys_.insert(input_key); } } private: TensorView* tv_ = nullptr; - DomainKeySet input_keys; + DomainKeySet input_keys_; public: static DomainKeySet find(TensorView* tv, const IterDomain* id) { @@ -297,6 +297,10 @@ void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) { auto use_chains = DependencyCheck::getAllUseChains(out_tv); for (const auto& chain : use_chains) { for (const auto& tv : ir_utils::filterByType(chain)) { + // Do not include the tensor itself in its consumers + if (tv == out_tv) { + continue; + } const auto& root_domain = tv->getRootDomain(); for (const auto& id : root_domain) { DomainKey consumer_key(tv->domain(), id); @@ -339,30 +343,41 @@ void UnmappableReductionDomains::handle(WelfordOp* op) { } bool UnmappableReductionDomains::isReductionOutputMapped( - const std::vector& consumer_domains, + const DomainKeySet& consumer_domains, const ComputeAtRootDomainMap& root_map) const { + // Check each reduction domain if any of the consumer domains + // conflicts with it for (const auto& kv : reduction_domains_) { const DomainKey& reduction_domain = kv.first; + // Domains that must not be mapped with the reduction domain const DomainKeySet& incompatible_domains = kv.second; - DomainKey consumer_domain_with_reduction; - bool reduction_found = false; + // Input domains to the reduction domain const auto& input_keys = reduction_domain_inputs_.at(reduction_domain); - for (const DomainKey& consumer_domain : consumer_domains) { - for (const auto& input_key : input_keys) { - if (input_key == consumer_domain) { - consumer_domain_with_reduction = consumer_domain; - reduction_found = true; - break; - } - } - } - if (!reduction_found) { + // Check if any of the consumer domains is an input to the + // reduction + auto it = std::find_if( + consumer_domains.begin(), + consumer_domains.end(), + [&](const auto& consumer_domain) { + return std::find( + input_keys.begin(), input_keys.end(), consumer_domain) != + input_keys.end(); + }); + // None of the consumer domains is used for the reduction + // domain. They should be safe with respect to this reduction + // domain + if (it == consumer_domains.end()) { continue; } - // Make sure no incompatible domains will be merged with the reduction - // domain. + + // A consumer domain that is an input to the reduction domain + const DomainKey& input_to_reduction = *it; + + // Check if mapping input_to_reduction with the other domains in + // consumer_domains. If there's a domain that is a consumer of the + // reduction, they must not be mapped together for (const auto& consumer_domain : consumer_domains) { - if (consumer_domain == consumer_domain_with_reduction) { + if (consumer_domain == input_to_reduction) { continue; } if (std::any_of( @@ -382,6 +397,27 @@ bool UnmappableReductionDomains::isReductionOutputMapped( return false; } +std::string UnmappableReductionDomains::toString() const { + std::stringstream ss; + ss << "Reduction-to-consumer map\n"; + for (const auto& kv : reduction_domains_) { + ss << "\tReduction: " << kv.first.toString() << "\n"; + for (const auto& mapped_val : kv.second) { + ss << "\t\tConsumer domain: " << mapped_val.toString() << "\n"; + } + } + + ss << "Reduction-to-producer map\n"; + for (const auto& kv : reduction_domain_inputs_) { + ss << "\tReduction: " << kv.first.toString() << "\n"; + for (const auto& mapped_val : kv.second) { + ss << "\t\tProducer domain: " << mapped_val.toString() << "\n"; + } + } + + return ss.str(); +} + void ComputeAtRootDomainMap::build(bool map_through_reduction) { // Make sure we start from scratch. Throw away previous results. eq_set_.clear(); @@ -724,7 +760,7 @@ void ComputeAtRootDomainMapBuilder::setInvalid( } bool ComputeAtRootDomainMapBuilder::isInvalid( - const std::vector& domains) const { + const DomainKeySet& domains) const { // First, collect all invalid mappings for each of the keys in domains DomainKeyMap invalid_key_map; for (const auto& key : domains) { @@ -741,8 +777,9 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( // Next, check if any pair is invalid to map. const auto num_keys = domains.size(); + const std::vector domains_vec({domains.begin(), domains.end()}); for (const auto i : c10::irange(num_keys)) { - const auto& key_i = domains[i]; + const auto& key_i = domains_vec[i]; // If no invalid keys found for key_i, it can be skipped. const auto invalid_key_map_it = invalid_key_map.find(key_i); if (invalid_key_map_it == invalid_key_map.end()) { @@ -755,7 +792,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( // If any other key in domains is identified mappable with any of // the keys in this set, the mapping with key_i is invalid. for (const auto j : c10::irange(i + 1, num_keys)) { - const auto& key_j = domains[j]; + const auto& key_j = domains_vec[j]; if (std::any_of( invalid_keys_for_i.begin(), invalid_keys_for_i.end(), @@ -1070,26 +1107,14 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { if (domains.size() <= 1) { return true; } - // Filter out equivalent domains - std::vector unique_domains; - for (const auto& domain : domains) { - if (std::none_of( - unique_domains.begin(), - unique_domains.end(), - [&](const auto& unique_dom) { - return root_map_.canMap(domain, unique_dom); - })) { - unique_domains.push_back(domain); - } - } + // Can't map if reduction output domains would be mapped - if (incompatible_domains_.isReductionOutputMapped( - unique_domains, root_map_) && + if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) && !map_through_reduction_) { return false; } // Make sure mapping these domains won't cause any invalid mapping - if (isInvalid(unique_domains)) { + if (isInvalid(domains)) { return false; } return true; diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index cfbd1ee8eba5..fa3d323ba6d2 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -145,6 +145,9 @@ class DomainKey { return td() == other.td() && id() == other.id() && concreteId() == other.concreteId(); } + bool operator!=(const DomainKey& other) const { + return !(*this == other); + } std::string toString() const; @@ -183,9 +186,11 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor { //! reduction outputs within the corresponding reduction loop is not //! possible. This routine is used to build root domain mappings. bool isReductionOutputMapped( - const std::vector& consumer_domains, + const DomainKeySet& consumer_domains, const ComputeAtRootDomainMap& root_map) const; + std::string toString() const; + private: using IterVisitor::handle; void handle(ReductionOp* op) override; @@ -365,7 +370,7 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder void setInvalid(const DomainKey& key1, const DomainKey& key2); //! Check if no pair of domains is invalid to map - bool isInvalid(const std::vector& domains) const; + bool isInvalid(const DomainKeySet& domains) const; //! Track a pair of producer-consumer domains as potentially mappable. Inserts //! entries into pending_map_, but does not add anything into the root_map_ diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 5ad6126760d7..be190bbb520e 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -3741,6 +3741,38 @@ TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__); } +// Repro of issue #1950 +TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + auto tv2 = makeSymbolicTensor(3); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + auto tv3 = set(tv0); + auto tv4 = mul(tv1, tv3); + auto tv5 = mul(tv1, tv2); + auto tv6 = mul(tv5, tv3); + auto tv7 = sum(tv6, {2}); + auto tv8 = broadcast(tv7, {false, false, true}); + auto tv9 = mul(tv3, tv8); + + // Issue #1950 was caused by a particular traversal ordering based + // on the output tensor ordering as below + fusion.addOutput(tv9); + fusion.addOutput(tv5); + fusion.addOutput(tv4); + + ComputeAtRootDomainMap root_map; + root_map.build(); + + checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false); +} + TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) { Fusion fusion; FusionGuard fg(&fusion);