forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Fix detection of unmappable root domains #1952
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -234,12 +234,12 @@ class FindInputDomains : BackwardVisitor { | |
private: | ||
FindInputDomains(TensorView* tv, const IterDomain* id) | ||
: BackwardVisitor(false), tv_(tv) { | ||
input_keys.insert(DomainKey(tv_->domain(), id)); | ||
input_keys_.insert(DomainKey(tv_->domain(), id)); | ||
} | ||
|
||
DomainKeySet find() { | ||
traverseFrom(tv_->fusion(), {tv_}); | ||
return input_keys; | ||
return input_keys_; | ||
} | ||
|
||
void handle(Expr* expr) override { | ||
|
@@ -261,21 +261,21 @@ class FindInputDomains : BackwardVisitor { | |
.mapConsumerToProducer(out_tv->domain(), in_tv->domain()); | ||
for (auto root_dom : out_tv->getRootDomain()) { | ||
DomainKey out_key({out_tv->domain(), root_dom}); | ||
if (input_keys.find(out_key) == input_keys.end()) { | ||
if (input_keys_.find(out_key) == input_keys_.end()) { | ||
continue; | ||
} | ||
auto input_id_it = c2p.find(root_dom); | ||
if (input_id_it == c2p.end()) { | ||
continue; | ||
} | ||
DomainKey input_key(in_tv->domain(), input_id_it->second); | ||
input_keys.insert(input_key); | ||
input_keys_.insert(input_key); | ||
} | ||
} | ||
|
||
private: | ||
TensorView* tv_ = nullptr; | ||
DomainKeySet input_keys; | ||
DomainKeySet input_keys_; | ||
|
||
public: | ||
static DomainKeySet find(TensorView* tv, const IterDomain* id) { | ||
|
@@ -297,6 +297,10 @@ void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) { | |
auto use_chains = DependencyCheck::getAllUseChains(out_tv); | ||
for (const auto& chain : use_chains) { | ||
for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) { | ||
// Do not include the tensor itself in its consumers | ||
if (tv == out_tv) { | ||
continue; | ||
} | ||
const auto& root_domain = tv->getRootDomain(); | ||
for (const auto& id : root_domain) { | ||
DomainKey consumer_key(tv->domain(), id); | ||
|
@@ -339,30 +343,41 @@ void UnmappableReductionDomains::handle(WelfordOp* op) { | |
} | ||
|
||
bool UnmappableReductionDomains::isReductionOutputMapped( | ||
const std::vector<DomainKey>& consumer_domains, | ||
const DomainKeySet& consumer_domains, | ||
const ComputeAtRootDomainMap& root_map) const { | ||
// Check each reduction domain if any of the consumer domains | ||
// conflicts with it | ||
for (const auto& kv : reduction_domains_) { | ||
const DomainKey& reduction_domain = kv.first; | ||
// Domains that must not be mapped with the reduction domain | ||
const DomainKeySet& incompatible_domains = kv.second; | ||
DomainKey consumer_domain_with_reduction; | ||
bool reduction_found = false; | ||
// Input domains to the reduction domain | ||
const auto& input_keys = reduction_domain_inputs_.at(reduction_domain); | ||
for (const DomainKey& consumer_domain : consumer_domains) { | ||
for (const auto& input_key : input_keys) { | ||
if (input_key == consumer_domain) { | ||
consumer_domain_with_reduction = consumer_domain; | ||
reduction_found = true; | ||
break; | ||
} | ||
} | ||
} | ||
if (!reduction_found) { | ||
// Check if any of the consumer domains is an input to the | ||
// reduction | ||
auto it = std::find_if( | ||
consumer_domains.begin(), | ||
consumer_domains.end(), | ||
[&](const auto& consumer_domain) { | ||
return std::find( | ||
input_keys.begin(), input_keys.end(), consumer_domain) != | ||
input_keys.end(); | ||
}); | ||
// None of the consumer domains is used for the reduction | ||
// domain. They should be safe with respect to this reduction | ||
// domain | ||
if (it == consumer_domains.end()) { | ||
continue; | ||
} | ||
// Make sure no incompatible domains will be merged with the reduction | ||
// domain. | ||
|
||
// A consumer domain that is an input to the reduction domain | ||
const DomainKey& input_to_reduction = *it; | ||
|
||
// Check if mapping input_to_reduction with the other domains in | ||
// consumer_domains. If there's a domain that is a consumer of the | ||
// reduction, they must not be mapped together | ||
for (const auto& consumer_domain : consumer_domains) { | ||
if (consumer_domain == consumer_domain_with_reduction) { | ||
if (consumer_domain == input_to_reduction) { | ||
Comment on lines
+346
to
+380
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just additional comments and cleanup |
||
continue; | ||
} | ||
if (std::any_of( | ||
|
@@ -382,6 +397,27 @@ bool UnmappableReductionDomains::isReductionOutputMapped( | |
return false; | ||
} | ||
|
||
std::string UnmappableReductionDomains::toString() const { | ||
std::stringstream ss; | ||
ss << "Reduction-to-consumer map\n"; | ||
for (const auto& kv : reduction_domains_) { | ||
ss << "\tReduction: " << kv.first.toString() << "\n"; | ||
for (const auto& mapped_val : kv.second) { | ||
ss << "\t\tConsumer domain: " << mapped_val.toString() << "\n"; | ||
} | ||
} | ||
|
||
ss << "Reduction-to-producer map\n"; | ||
for (const auto& kv : reduction_domain_inputs_) { | ||
ss << "\tReduction: " << kv.first.toString() << "\n"; | ||
for (const auto& mapped_val : kv.second) { | ||
ss << "\t\tProducer domain: " << mapped_val.toString() << "\n"; | ||
} | ||
} | ||
|
||
return ss.str(); | ||
} | ||
|
||
void ComputeAtRootDomainMap::build(bool map_through_reduction) { | ||
// Make sure we start from scratch. Throw away previous results. | ||
eq_set_.clear(); | ||
|
@@ -724,7 +760,7 @@ void ComputeAtRootDomainMapBuilder::setInvalid( | |
} | ||
|
||
bool ComputeAtRootDomainMapBuilder::isInvalid( | ||
const std::vector<DomainKey>& domains) const { | ||
const DomainKeySet& domains) const { | ||
// First, collect all invalid mappings for each of the keys in domains | ||
DomainKeyMap<DomainKeySet> invalid_key_map; | ||
for (const auto& key : domains) { | ||
|
@@ -741,8 +777,9 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( | |
|
||
// Next, check if any pair is invalid to map. | ||
const auto num_keys = domains.size(); | ||
const std::vector<DomainKey> domains_vec({domains.begin(), domains.end()}); | ||
for (const auto i : c10::irange(num_keys)) { | ||
const auto& key_i = domains[i]; | ||
const auto& key_i = domains_vec[i]; | ||
// If no invalid keys found for key_i, it can be skipped. | ||
const auto invalid_key_map_it = invalid_key_map.find(key_i); | ||
if (invalid_key_map_it == invalid_key_map.end()) { | ||
|
@@ -755,7 +792,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid( | |
// If any other key in domains is identified mappable with any of | ||
// the keys in this set, the mapping with key_i is invalid. | ||
for (const auto j : c10::irange(i + 1, num_keys)) { | ||
const auto& key_j = domains[j]; | ||
const auto& key_j = domains_vec[j]; | ||
if (std::any_of( | ||
invalid_keys_for_i.begin(), | ||
invalid_keys_for_i.end(), | ||
|
@@ -1070,26 +1107,14 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) { | |
if (domains.size() <= 1) { | ||
return true; | ||
} | ||
// Filter out equivalent domains | ||
std::vector<DomainKey> unique_domains; | ||
for (const auto& domain : domains) { | ||
if (std::none_of( | ||
unique_domains.begin(), | ||
unique_domains.end(), | ||
[&](const auto& unique_dom) { | ||
return root_map_.canMap(domain, unique_dom); | ||
})) { | ||
unique_domains.push_back(domain); | ||
} | ||
} | ||
|
||
// Can't map if reduction output domains would be mapped | ||
if (incompatible_domains_.isReductionOutputMapped( | ||
unique_domains, root_map_) && | ||
if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) && | ||
Comment on lines
+1110
to
+1112
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the main fix. Just checking unique domains isn't sufficient. |
||
!map_through_reduction_) { | ||
return false; | ||
} | ||
// Make sure mapping these domains won't cause any invalid mapping | ||
if (isInvalid(unique_domains)) { | ||
if (isInvalid(domains)) { | ||
return false; | ||
} | ||
return true; | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is another related fix. Here, we are gathering IterDomains of all consumer tensors of a tensor with a reduction IterDomain. For this logic, we don't want to have the same tensor as its consumer.