Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 64 additions & 39 deletions torch/csrc/jit/codegen/cuda/root_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ class FindInputDomains : BackwardVisitor {
private:
FindInputDomains(TensorView* tv, const IterDomain* id)
: BackwardVisitor(false), tv_(tv) {
input_keys.insert(DomainKey(tv_->domain(), id));
input_keys_.insert(DomainKey(tv_->domain(), id));
}

DomainKeySet find() {
traverseFrom(tv_->fusion(), {tv_});
return input_keys;
return input_keys_;
}

void handle(Expr* expr) override {
Expand All @@ -261,21 +261,21 @@ class FindInputDomains : BackwardVisitor {
.mapConsumerToProducer(out_tv->domain(), in_tv->domain());
for (auto root_dom : out_tv->getRootDomain()) {
DomainKey out_key({out_tv->domain(), root_dom});
if (input_keys.find(out_key) == input_keys.end()) {
if (input_keys_.find(out_key) == input_keys_.end()) {
continue;
}
auto input_id_it = c2p.find(root_dom);
if (input_id_it == c2p.end()) {
continue;
}
DomainKey input_key(in_tv->domain(), input_id_it->second);
input_keys.insert(input_key);
input_keys_.insert(input_key);
}
}

private:
TensorView* tv_ = nullptr;
DomainKeySet input_keys;
DomainKeySet input_keys_;

public:
static DomainKeySet find(TensorView* tv, const IterDomain* id) {
Expand All @@ -297,6 +297,10 @@ void UnmappableReductionDomains::handleReductionOutput(TensorView* out_tv) {
auto use_chains = DependencyCheck::getAllUseChains(out_tv);
for (const auto& chain : use_chains) {
for (const auto& tv : ir_utils::filterByType<TensorView>(chain)) {
// Do not include the tensor itself in its consumers
if (tv == out_tv) {
continue;
}
Comment on lines +300 to +303
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another related fix. Here, we gather the IterDomains of all consumer tensors of a tensor that has a reduction IterDomain. For this logic, the tensor itself must not be treated as one of its own consumers.

const auto& root_domain = tv->getRootDomain();
for (const auto& id : root_domain) {
DomainKey consumer_key(tv->domain(), id);
Expand Down Expand Up @@ -339,30 +343,41 @@ void UnmappableReductionDomains::handle(WelfordOp* op) {
}

bool UnmappableReductionDomains::isReductionOutputMapped(
const std::vector<DomainKey>& consumer_domains,
const DomainKeySet& consumer_domains,
const ComputeAtRootDomainMap& root_map) const {
// Check each reduction domain if any of the consumer domains
// conflicts with it
for (const auto& kv : reduction_domains_) {
const DomainKey& reduction_domain = kv.first;
// Domains that must not be mapped with the reduction domain
const DomainKeySet& incompatible_domains = kv.second;
DomainKey consumer_domain_with_reduction;
bool reduction_found = false;
// Input domains to the reduction domain
const auto& input_keys = reduction_domain_inputs_.at(reduction_domain);
for (const DomainKey& consumer_domain : consumer_domains) {
for (const auto& input_key : input_keys) {
if (input_key == consumer_domain) {
consumer_domain_with_reduction = consumer_domain;
reduction_found = true;
break;
}
}
}
if (!reduction_found) {
// Check if any of the consumer domains is an input to the
// reduction
auto it = std::find_if(
consumer_domains.begin(),
consumer_domains.end(),
[&](const auto& consumer_domain) {
return std::find(
input_keys.begin(), input_keys.end(), consumer_domain) !=
input_keys.end();
});
// None of the consumer domains is used for the reduction
// domain. They should be safe with respect to this reduction
// domain
if (it == consumer_domains.end()) {
continue;
}
// Make sure no incompatible domains will be merged with the reduction
// domain.

// A consumer domain that is an input to the reduction domain
const DomainKey& input_to_reduction = *it;

// Check whether input_to_reduction may be mapped with the other
// domains in consumer_domains. If any of them is a consumer of the
// reduction, they must not be mapped together
for (const auto& consumer_domain : consumer_domains) {
if (consumer_domain == consumer_domain_with_reduction) {
if (consumer_domain == input_to_reduction) {
Comment on lines +346 to +380
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just additional comments and cleanup

continue;
}
if (std::any_of(
Expand All @@ -382,6 +397,27 @@ bool UnmappableReductionDomains::isReductionOutputMapped(
return false;
}

// Render the contents of reduction_domains_ and reduction_domain_inputs_
// as indented, human-readable text. For every reduction key, the first
// section lists the entries of reduction_domains_ ("Consumer domain")
// and the second lists the entries of reduction_domain_inputs_
// ("Producer domain").
std::string UnmappableReductionDomains::toString() const {
  std::stringstream out;

  // Section 1: reduction domain -> domains recorded in reduction_domains_
  out << "Reduction-to-consumer map\n";
  for (const auto& entry : reduction_domains_) {
    const auto& reduction = entry.first;
    const auto& consumers = entry.second;
    out << "\tReduction: " << reduction.toString() << "\n";
    for (const auto& consumer : consumers) {
      out << "\t\tConsumer domain: " << consumer.toString() << "\n";
    }
  }

  // Section 2: reduction domain -> domains recorded in
  // reduction_domain_inputs_
  out << "Reduction-to-producer map\n";
  for (const auto& entry : reduction_domain_inputs_) {
    const auto& reduction = entry.first;
    const auto& producers = entry.second;
    out << "\tReduction: " << reduction.toString() << "\n";
    for (const auto& producer : producers) {
      out << "\t\tProducer domain: " << producer.toString() << "\n";
    }
  }

  return out.str();
}

void ComputeAtRootDomainMap::build(bool map_through_reduction) {
// Make sure we start from scratch. Throw away previous results.
eq_set_.clear();
Expand Down Expand Up @@ -724,7 +760,7 @@ void ComputeAtRootDomainMapBuilder::setInvalid(
}

bool ComputeAtRootDomainMapBuilder::isInvalid(
const std::vector<DomainKey>& domains) const {
const DomainKeySet& domains) const {
// First, collect all invalid mappings for each of the keys in domains
DomainKeyMap<DomainKeySet> invalid_key_map;
for (const auto& key : domains) {
Expand All @@ -741,8 +777,9 @@ bool ComputeAtRootDomainMapBuilder::isInvalid(

// Next, check if any pair is invalid to map.
const auto num_keys = domains.size();
const std::vector<DomainKey> domains_vec({domains.begin(), domains.end()});
for (const auto i : c10::irange(num_keys)) {
const auto& key_i = domains[i];
const auto& key_i = domains_vec[i];
// If no invalid keys found for key_i, it can be skipped.
const auto invalid_key_map_it = invalid_key_map.find(key_i);
if (invalid_key_map_it == invalid_key_map.end()) {
Expand All @@ -755,7 +792,7 @@ bool ComputeAtRootDomainMapBuilder::isInvalid(
// If any other key in domains is identified mappable with any of
// the keys in this set, the mapping with key_i is invalid.
for (const auto j : c10::irange(i + 1, num_keys)) {
const auto& key_j = domains[j];
const auto& key_j = domains_vec[j];
if (std::any_of(
invalid_keys_for_i.begin(),
invalid_keys_for_i.end(),
Expand Down Expand Up @@ -1070,26 +1107,14 @@ bool ComputeAtRootDomainMapBuilder::safeToMap(const DomainKeySet& domains) {
if (domains.size() <= 1) {
return true;
}
// Filter out equivalent domains
std::vector<DomainKey> unique_domains;
for (const auto& domain : domains) {
if (std::none_of(
unique_domains.begin(),
unique_domains.end(),
[&](const auto& unique_dom) {
return root_map_.canMap(domain, unique_dom);
})) {
unique_domains.push_back(domain);
}
}

// Can't map if reduction output domains would be mapped
if (incompatible_domains_.isReductionOutputMapped(
unique_domains, root_map_) &&
if (incompatible_domains_.isReductionOutputMapped(domains, root_map_) &&
Comment on lines +1110 to +1112
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main fix. Just checking unique domains isn't sufficient.

!map_through_reduction_) {
return false;
}
// Make sure mapping these domains won't cause any invalid mapping
if (isInvalid(unique_domains)) {
if (isInvalid(domains)) {
return false;
}
return true;
Expand Down
9 changes: 7 additions & 2 deletions torch/csrc/jit/codegen/cuda/root_domain_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class DomainKey {
return td() == other.td() && id() == other.id() &&
concreteId() == other.concreteId();
}
//! Inequality, defined as the exact negation of the equality comparison
//! so that the two can never disagree.
bool operator!=(const DomainKey& other) const {
  const bool is_equal = (*this == other);
  return !is_equal;
}

std::string toString() const;

Expand Down Expand Up @@ -183,9 +186,11 @@ class TORCH_CUDA_CU_API UnmappableReductionDomains : private IterVisitor {
//! reduction outputs within the corresponding reduction loop is not
//! possible. This routine is used to build root domain mappings.
bool isReductionOutputMapped(
const std::vector<DomainKey>& consumer_domains,
const DomainKeySet& consumer_domains,
const ComputeAtRootDomainMap& root_map) const;

std::string toString() const;

private:
using IterVisitor::handle;
void handle(ReductionOp* op) override;
Expand Down Expand Up @@ -365,7 +370,7 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMapBuilder
void setInvalid(const DomainKey& key1, const DomainKey& key2);

//! Check if no pair of domains is invalid to map
bool isInvalid(const std::vector<DomainKey>& domains) const;
bool isInvalid(const DomainKeySet& domains) const;

//! Track a pair of producer-consumer domains as potentially mappable. Inserts
//! entries into pending_map_, but does not add anything into the root_map_
Expand Down
32 changes: 32 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3741,6 +3741,38 @@ TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) {
testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__);
}

// Repro of issue #1950
//
// Builds a small fusion containing a reduction followed by a broadcast,
// constructs the ComputeAtRootDomainMap for it, and checks the mapping
// between the innermost axes of tv4 and tv9. The order in which the
// outputs are registered is part of the repro and must not be changed.
TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
// Three fusion inputs, each created with makeSymbolicTensor(3)
// (presumably 3-D symbolic tensors -- confirm against the helper)
auto tv0 = makeSymbolicTensor(3);
auto tv1 = makeSymbolicTensor(3);
auto tv2 = makeSymbolicTensor(3);

fusion.addInput(tv0);
fusion.addInput(tv1);
fusion.addInput(tv2);

// tv3 feeds three places: tv4, the reduction input tv6, and tv9, which
// also consumes the reduction result through tv8
auto tv3 = set(tv0);
auto tv4 = mul(tv1, tv3);
auto tv5 = mul(tv1, tv2);
auto tv6 = mul(tv5, tv3);
// Reduce tv6 over axis 2, then broadcast the result with the third
// flag set (presumably re-adding the reduced position as a broadcast
// axis -- confirm against broadcast())
auto tv7 = sum(tv6, {2});
auto tv8 = broadcast(tv7, {false, false, true});
auto tv9 = mul(tv3, tv8);

// Issue #1950 was caused by a particular traversal ordering based
// on the output tensor ordering as below
fusion.addOutput(tv9);
fusion.addOutput(tv5);
fusion.addOutput(tv4);

ComputeAtRootDomainMap root_map;
root_map.build();

// The trailing `false` is presumably the expected mapping result, i.e.
// the innermost axes of tv4 and tv9 must NOT be mapped -- confirm
// against checkIdMapped's signature
checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false);
}

TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
Expand Down