// Patch overview. This text sits before the first "diff --git" header and is
// not part of any hunk. The patch touches three files under
// torch/csrc/jit/codegen/cuda/.
//
// root_domain_map.cpp
//   * FindInputDomains: the member `input_keys` is renamed to `input_keys_`
//     at its declaration and at every use shown in the hunks.
//   * UnmappableReductionDomains::handleReductionOutput: while walking the use
//     chains of `out_tv`, a tensor equal to `out_tv` itself is now skipped.
//   * UnmappableReductionDomains::isReductionOutputMapped: `consumer_domains`
//     becomes a `const DomainKeySet&`. The nested loops that set
//     `consumer_domain_with_reduction` / `reduction_found` are replaced by
//     std::find_if over `consumer_domains` with std::find over `input_keys`;
//     the match is bound to the reference `input_to_reduction`.
//   * UnmappableReductionDomains::toString: new. Streams every entry of
//     `reduction_domains_` and of `reduction_domain_inputs_` into a
//     std::stringstream and returns the resulting string.
//   * ComputeAtRootDomainMapBuilder::isInvalid: `domains` becomes a
//     `const DomainKeySet&`. It is copied into the local `domains_vec` so the
//     index-based pair loops (i, then j from i + 1) can keep using operator[].
//   * ComputeAtRootDomainMapBuilder::safeToMap: the "Filter out equivalent
//     domains" step that built `unique_domains` via root_map_.canMap is
//     removed; `domains` is passed directly to isReductionOutputMapped and
//     isInvalid.
//
// root_domain_map.h
//   * DomainKey gains operator!=, defined as the negation of operator==.
//   * Declarations of isReductionOutputMapped and isInvalid are updated to
//     take `const DomainKeySet&`; toString() is declared on
//     UnmappableReductionDomains.
//
// test/test_gpu.cpp
//   * Adds FusionRootMappingRepro1950_CUDA ("Repro of issue #1950"). Outputs
//     are registered in the order tv9, tv5, tv4, the map is built, and
//     checkIdMapped is called on tv4->axis(-1) and tv9->axis(-1) with a final
//     argument of `false` (presumably "expected not mapped" -- confirm against
//     checkIdMapped).
//
// NOTE(review): the original line breaks of this patch are collapsed, and
// template argument lists look stripped (e.g. `std::vector&`,
// `filterByType(chain)`, `DomainKeyMap invalid_key_map`) -- presumably an
// extraction artifact. Confirm against the upstream commit before applying.
diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index 06a6d5ec1d17..fdf6f15b24e0 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -234,12 +234,12 @@ class FindInputDomains : BackwardVisitor { private: FindInputDomains(TensorView* tv, const IterDomain* id) : BackwardVisitor(false), tv_(tv) { - input_keys.insert(DomainKey(tv_->domain(), id)); + input_keys_.insert(DomainKey(tv_->domain(), id)); } DomainKeySet find() { traverseFrom(tv_->fusion(), {tv_}); - return input_keys; + return input_keys_; } void handle(Expr* expr) override { @@ -261,7 +261,7 @@ class FindInputDomains : BackwardVisitor { .mapConsumerToProducer(out_tv->domain(), in_tv->domain()); for (auto root_dom : out_tv->getRootDomain()) { DomainKey out_key({out_tv->domain(), root_dom}); - if (input_keys.find(out_key) == input_keys.end()) { + if (input_keys_.find(out_key) == input_keys_.end()) { continue; } auto input_id_it = c2p.find(root_dom); @@ -269,13 +269,13 @@ class FindInputDomains : BackwardVisitor { continue; } DomainKey input_key(in_tv->domain(), input_id_it->second); - input_keys.insert(input_key); + input_keys_.insert(input_key); } } private: TensorView* tv_ = nullptr; - DomainKeySet input_keys; + DomainKeySet input_keys_; public: static DomainKeySet find(TensorView* tv, const IterDomain* id) { @@ -297,6 +297,10 @@ void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) { auto use_chains = DependencyCheck::getAllUseChains(out_tv); for (const auto& chain : use_chains) { for (const auto& tv : ir_utils::filterByType(chain)) { + // Do not include the tensor itself in its consumers + if (tv == out_tv) { + continue; + } const auto& root_domain = tv->getRootDomain(); for (const auto& id : root_domain) { DomainKey consumer_key(tv->domain(), id); @@ -339,30 +343,41 @@ void UnmappableReductionDomains::handle(WelfordOp* op) { } bool 
UnmappableReductionDomains::isReductionOutputMapped( - const std::vector& consumer_domains, + const DomainKeySet& consumer_domains, const ComputeAtRootDomainMap& root_map) const { + // Check each reduction domain if any of the consumer domains + // conflicts with it for (const auto& kv : reduction_domains_) { const DomainKey& reduction_domain = kv.first; + // Domains that must not be mapped with the reduction domain const DomainKeySet& incompatible_domains = kv.second; - DomainKey consumer_domain_with_reduction; - bool reduction_found = false; + // Input domains to the reduction domain const auto& input_keys = reduction_domain_inputs_.at(reduction_domain); - for (const DomainKey& consumer_domain : consumer_domains) { - for (const auto& input_key : input_keys) { - if (input_key == consumer_domain) { - consumer_domain_with_reduction = consumer_domain; - reduction_found = true; - break; - } - } - } - if (!reduction_found) { + // Check if any of the consumer domains is an input to the + // reduction + auto it = std::find_if( + consumer_domains.begin(), + consumer_domains.end(), + [&](const auto& consumer_domain) { + return std::find( + input_keys.begin(), input_keys.end(), consumer_domain) != + input_keys.end(); + }); + // None of the consumer domains is used for the reduction + // domain. They should be safe with respect to this reduction + // domain + if (it == consumer_domains.end()) { continue; } - // Make sure no incompatible domains will be merged with the reduction - // domain. + + // A consumer domain that is an input to the reduction domain + const DomainKey& input_to_reduction = *it; + + // Check whether input_to_reduction may be mapped with the other + // domains in consumer_domains. 
If there's a domain that is a consumer of the + // reduction, they must not be mapped together for (const auto& consumer_domain : consumer_domains) { - if (consumer_domain == consumer_domain_with_reduction) { + if (consumer_domain == input_to_reduction) { continue; } if (std::any_of( @@ -382,6 +397,27 @@ bool UnmappableReductionDomains::isReductionOutputMapped( return false; } +std::string UnmappableReductionDomains::toString() const { + std::stringstream ss; + ss << "Reduction-to-consumer map\n"; + for (const auto& kv : reduction_domains_) { + ss << "\tReduction: " << kv.first.toString() << "\n"; + for (const auto& mapped_val : kv.second) { + ss << "\t\tConsumer domain: " << mapped_val.toString() << "\n"; + } + } + + ss << "Reduction-to-producer map\n"; + for (const auto& kv : reduction_domain_inputs_) { + ss << "\tReduction: " << kv.first.toString() << "\n"; + for (const auto& mapped_val : kv.second) { + ss << "\t\tProducer domain: " << mapped_val.toString() << "\n"; + } + } + + return ss.str(); +} + void ComputeAtRootDomainMap::build(bool map_through_reduction) { // Make sure we start from scratch. Throw away previous results. eq_set_.clear(); @@ -724,7 +760,7 @@ void ComputeAtRootDomainMapBuilder::setInvalid( } bool ComputeAtRootDomainMapBuilder::isInvalid( - const std::vector& domains) const { + const DomainKeySet& domains) const { // First, collect all invalid mappings for each of the keys in domains DomainKeyMap invalid_key_map; for (const auto& key : domains) { @@ -741,8 +777,9 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( // Next, check if any pair is invalid to map. const auto num_keys = domains.size(); + const std::vector domains_vec({domains.begin(), domains.end()}); for (const auto i : c10::irange(num_keys)) { - const auto& key_i = domains[i]; + const auto& key_i = domains_vec[i]; // If no invalid keys found for key_i, it can be skipped. 
const auto invalid_key_map_it = invalid_key_map.find(key_i); if (invalid_key_map_it == invalid_key_map.end()) { @@ -755,7 +792,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( // If any other key in domains is identified mappable with any of // the keys in this set, the mapping with key_i is invalid. for (const auto j : c10::irange(i + 1, num_keys)) { - const auto& key_j = domains[j]; + const auto& key_j = domains_vec[j]; if (std::any_of( invalid_keys_for_i.begin(), invalid_keys_for_i.end(), @@ -1070,26 +1107,14 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { if (domains.size() <= 1) { return true; } - // Filter out equivalent domains - std::vector unique_domains; - for (const auto& domain : domains) { - if (std::none_of( - unique_domains.begin(), - unique_domains.end(), - [&](const auto& unique_dom) { - return root_map_.canMap(domain, unique_dom); - })) { - unique_domains.push_back(domain); - } - } + // Can't map if reduction output domains would be mapped - if (incompatible_domains_.isReductionOutputMapped( - unique_domains, root_map_) && + if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) && !map_through_reduction_) { return false; } // Make sure mapping these domains won't cause any invalid mapping - if (isInvalid(unique_domains)) { + if (isInvalid(domains)) { return false; } return true; diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index cfbd1ee8eba5..fa3d323ba6d2 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -145,6 +145,9 @@ class DomainKey { return td() == other.td() && id() == other.id() && concreteId() == other.concreteId(); } + bool operator!=(const DomainKey& other) const { + return !(*this == other); + } std::string toString() const; @@ -183,9 +186,11 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor { //! 
reduction outputs within the corresponding reduction loop is not //! possible. This routine is used to build root domain mappings. bool isReductionOutputMapped( - const std::vector& consumer_domains, + const DomainKeySet& consumer_domains, const ComputeAtRootDomainMap& root_map) const; + std::string toString() const; + private: using IterVisitor::handle; void handle(ReductionOp* op) override; @@ -365,7 +370,7 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder void setInvalid(const DomainKey& key1, const DomainKey& key2); //! Check if no pair of domains is invalid to map - bool isInvalid(const std::vector& domains) const; + bool isInvalid(const DomainKeySet& domains) const; //! Track a pair of producer-consumer domains as potentially mappable. Inserts //! entries into pending_map_, but does not add anything into the root_map_ diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 5ad6126760d7..be190bbb520e 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -3741,6 +3741,38 @@ TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__); } +// Repro of issue #1950 +TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + auto tv2 = makeSymbolicTensor(3); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + auto tv3 = set(tv0); + auto tv4 = mul(tv1, tv3); + auto tv5 = mul(tv1, tv2); + auto tv6 = mul(tv5, tv3); + auto tv7 = sum(tv6, {2}); + auto tv8 = broadcast(tv7, {false, false, true}); + auto tv9 = mul(tv3, tv8); + + // Issue #1950 was caused by a particular traversal ordering based + // on the output tensor ordering as below + fusion.addOutput(tv9); + fusion.addOutput(tv5); + fusion.addOutput(tv4); + + ComputeAtRootDomainMap 
root_map; + root_map.build(); + + checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false); +} + TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) { Fusion fusion; FusionGuard fg(&fusion);